Skip to content

Commit

Permalink
Merge pull request #170 from Chase-Grajeda/deepset-refactor
Browse files Browse the repository at this point in the history
Exposed Deepset
  • Loading branch information
stefanradev93 authored Jun 4, 2024
2 parents 142d61f + 969155f commit e59f077
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 0 deletions.
1 change: 1 addition & 0 deletions bayesflow/experimental/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .flow_matching import FlowMatching
from .resnet import ResNet
from .set_transformer import SetTransformer
from .deep_set.deep_set import DeepSet
97 changes: 97 additions & 0 deletions bayesflow/experimental/networks/deep_set/deep_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from .invariant_module import InvariantModule
from .equivariant_module import EquivariantModule
from bayesflow import default_settings as defaults
from keras.api.layers import Dense
from keras import Sequential
import keras

class DeepSet(keras.Model):
def __init__(
self,
summary_dim: int = 10,
num_dense_s1: int = 2,
num_dense_s2: int = 2,
num_dense_s3: int = 3,
num_equiv: int = 2,
dense_s1_args=None,
dense_s2_args=None,
dense_s3_args=None,
pooling_fun: str = "mean",
**kwargs
):
"""Creates a stack of 'num_equiv' equivariant layers followed by a final invariant layer.
Parameters
----------
summary_dim : int, optional, default: 10
The number of learned summary statistics.
num_dense_s1 : int, optional, default: 2
The number of dense layers in the inner function of a deep set.
num_dense_s2 : int, optional, default: 2
The number of dense layers in the outer function of a deep set.
num_dense_s3 : int, optional, default: 2
The number of dense layers in an equivariant layer.
num_equiv : int, optional, default: 2
The number of equivariant layers in the network.
dense_s1_args : dict or None, optional, default: None
The arguments for the dense layers of s1 (inner, pre-pooling function). If `None`,
defaults will be used (see `default_settings`). Otherwise, all arguments for a
tf.keras.layers.Dense layer are supported.
dense_s2_args : dict or None, optional, default: None
The arguments for the dense layers of s2 (outer, post-pooling function). If `None`,
defaults will be used (see `default_settings`). Otherwise, all arguments for a
tf.keras.layers.Dense layer are supported.
dense_s3_args : dict or None, optional, default: None
The arguments for the dense layers of s3 (equivariant function). If `None`,
defaults will be used (see `default_settings`). Otherwise, all arguments for a
tf.keras.layers.Dense layer are supported.
pooling_fun : str of callable, optional, default: 'mean'
If string argument provided, should be one in ['mean', 'max']. In addition, ac actual
neural network can be passed for learnable pooling.
**kwargs : dict, optional, default: {}
Optional keyword arguments passed to the __init__() method of tf.keras.Model.
"""

super().__init__(**kwargs)

# Prepare settings dictionary
settings = dict(
num_dense_s1=num_dense_s1,
num_dense_s2=num_dense_s2,
num_dense_s3=num_dense_s3,
dense_s1_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s1_args is None else dense_s1_args,
dense_s2_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s2_args is None else dense_s2_args,
dense_s3_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s3_args is None else dense_s3_args,
pooling_fun=pooling_fun,
)

# Create equivariant layers and final invariant layer
self.equiv_layers = Sequential([EquivariantModule(settings) for _ in range(num_equiv)])
self.inv = InvariantModule(settings)

# Output layer to output "summary_dim" learned summary statistics
self.out_layer = Dense(summary_dim, activation="linear")
self.summary_dim = summary_dim

def call(self, x, **kwargs):
"""Performs the forward pass of a learnable deep invariant transformation consisting of
a sequence of equivariant transforms followed by an invariant transform.
Parameters
----------
x : tf.Tensor
Input of shape (batch_size, n_obs, data_dim)
Returns
-------
out : tf.Tensor
Output of shape (batch_size, out_dim)
"""

# Pass through series of augmented equivariant transforms
out_equiv = self.equiv_layers(x, **kwargs)

# Pass through final invariant layer
out = self.out_layer(self.inv(out_equiv, **kwargs), **kwargs)

return out
61 changes: 61 additions & 0 deletions bayesflow/experimental/networks/deep_set/equivariant_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from .invariant_module import InvariantModule
from keras import Sequential
from keras.api.layers import Dense
import keras

class EquivariantModule(keras.Model):
"""Implements an equivariant module performing an equivariant transform.
For details and justification, see:
[1] Bloem-Reddy, B., & Teh, Y. W. (2020). Probabilistic Symmetries and Invariant Neural Networks.
J. Mach. Learn. Res., 21, 90-1. https://www.jmlr.org/papers/volume21/19-322/19-322.pdf
"""

def __init__(self, settings, **kwargs):
"""Creates an equivariant module according to [1] which combines equivariant transforms
with nested invariant transforms, thereby enabling interactions between set members.
Parameters
----------
settings : dict
A dictionary holding the configuration settings for the module.
**kwargs : dict, optional, default: {}
Optional keyword arguments passed to the ``tf.keras.Model`` constructor.
"""

super().__init__(**kwargs)

self.invariant_module = InvariantModule(settings)
self.s3 = Sequential([Dense(**settings["dense_s3_args"]) for _ in range(settings["num_dense_s3"])])

def call(self, x, **kwargs):
"""Performs the forward pass of a learnable equivariant transform.
Parameters
----------
x : tf.Tensor
Input of shape (batch_size, ..., x_dim)
Returns
-------
out : tf.Tensor
Output of shape (batch_size, ..., equiv_dim)
"""

# Store shape of x, will be (batch_size, ..., some_dim)
shape = keras.ops.shape(x)

# Example: Output dim is (batch_size, inv_dim) - > (batch_size, N, inv_dim)
out_inv = self.invariant_module(x, **kwargs)
out_inv = keras.ops.expand_dims(out_inv, -2)
tiler = [1] * len(shape)
tiler[-2] = shape[-2]
out_inv_rep = keras.ops.tile(out_inv, tiler)

# Concatenate each x with the repeated invariant embedding
out_c = keras.ops.concatenate([x, out_inv_rep], axis=-1)

# Pass through equivariant func
out = self.s3(out_c, **kwargs)
return out
63 changes: 63 additions & 0 deletions bayesflow/experimental/networks/deep_set/invariant_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import keras
from keras import Sequential
from keras.api.layers import Dense
from bayesflow.exceptions import ConfigurationError
from functools import partial
import tensorflow as tf

class InvariantModule(keras.Model):
"""Implements an invariant module performing a permutation-invariant transform.
For details and rationale, see:
[1] Bloem-Reddy, B., & Teh, Y. W. (2020). Probabilistic Symmetries and Invariant Neural Networks.
J. Mach. Learn. Res., 21, 90-1. https://www.jmlr.org/papers/volume21/19-322/19-322.pdf
"""

def __init__(self, settings, **kwargs):
"""Creates an invariant module according to [1] which represents a learnable permutation-invariant
function with an option for learnable pooling.
Parameters
----------
settings : dict
A dictionary holding the configuration settings for the module.
**kwargs : dict, optional, default: {}
Optional keyword arguments passed to the `tf.keras.Model` constructor.
"""

super().__init__(**kwargs)

# Create internal functions
self.s1 = Sequential([Dense(**settings["dense_s1_args"]) for _ in range(settings["num_dense_s1"])])
self.s2 = Sequential([Dense(**settings["dense_s2_args"]) for _ in range(settings["num_dense_s2"])])

# Pick pooling function
if settings["pooling_fun"] == "mean":
pooling_fun = partial(tf.reduce_mean, axis=-2)
elif settings["pooling_fun"] == "max":
pooling_fun = partial(tf.reduce_max, axis=-2)
else:
if callable(settings["pooling_fun"]):
pooling_fun = settings["pooling_fun"]
else:
raise ConfigurationError("pooling_fun argument not understood!")
self.pooler = pooling_fun

def call(self, x, **kwargs):
"""Performs the forward pass of a learnable invariant transform.
Parameters
----------
x : tf.Tensor
Input of shape (batch_size,..., x_dim)
Returns
-------
out : tf.Tensor
Output of shape (batch_size,..., out_dim)
"""

x_reduced = self.pooler(self.s1(x, **kwargs))
out = self.s2(x_reduced, **kwargs)
return out

0 comments on commit e59f077

Please sign in to comment.