-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #170 from Chase-Grajeda/deepset-refactor
Exposed Deepset
- Loading branch information
Showing
4 changed files
with
222 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from .invariant_module import InvariantModule | ||
from .equivariant_module import EquivariantModule | ||
from bayesflow import default_settings as defaults | ||
from keras.api.layers import Dense | ||
from keras import Sequential | ||
import keras | ||
|
||
class DeepSet(keras.Model):
    """Deep set network: a stack of equivariant modules, one invariant module, and a linear output layer."""

    def __init__(
        self,
        summary_dim: int = 10,
        num_dense_s1: int = 2,
        num_dense_s2: int = 2,
        num_dense_s3: int = 3,
        num_equiv: int = 2,
        dense_s1_args=None,
        dense_s2_args=None,
        dense_s3_args=None,
        pooling_fun: str = "mean",
        **kwargs
    ):
        """Creates a stack of 'num_equiv' equivariant layers followed by a final invariant layer.

        Parameters
        ----------
        summary_dim   : int, optional, default: 10
            The number of learned summary statistics.
        num_dense_s1  : int, optional, default: 2
            The number of dense layers in the inner function of a deep set.
        num_dense_s2  : int, optional, default: 2
            The number of dense layers in the outer function of a deep set.
        num_dense_s3  : int, optional, default: 3
            The number of dense layers in an equivariant layer.
        num_equiv     : int, optional, default: 2
            The number of equivariant layers in the network.
        dense_s1_args : dict or None, optional, default: None
            The arguments for the dense layers of s1 (inner, pre-pooling function). If `None`,
            defaults will be used (see `default_settings`). Otherwise, the dictionary is
            unpacked into the `Dense` layer constructor.
        dense_s2_args : dict or None, optional, default: None
            The arguments for the dense layers of s2 (outer, post-pooling function). If `None`,
            defaults will be used (see `default_settings`). Otherwise, the dictionary is
            unpacked into the `Dense` layer constructor.
        dense_s3_args : dict or None, optional, default: None
            The arguments for the dense layers of s3 (equivariant function). If `None`,
            defaults will be used (see `default_settings`). Otherwise, the dictionary is
            unpacked into the `Dense` layer constructor.
        pooling_fun   : str or callable, optional, default: 'mean'
            If a string argument is provided, it should be one of ['mean', 'max']. In addition,
            an actual neural network can be passed for learnable pooling.
        **kwargs      : dict, optional, default: {}
            Optional keyword arguments passed to the __init__() method of keras.Model.
        """

        super().__init__(**kwargs)

        def _resolve_dense_args(user_args):
            # Shallow-copy so that the shared module-level default (and the caller's
            # own dict) is never aliased by the settings handed to the sub-modules.
            chosen = defaults.DEFAULT_SETTING_DENSE_DEEP_SET if user_args is None else user_args
            return dict(chosen)

        # Prepare settings dictionary
        settings = dict(
            num_dense_s1=num_dense_s1,
            num_dense_s2=num_dense_s2,
            num_dense_s3=num_dense_s3,
            dense_s1_args=_resolve_dense_args(dense_s1_args),
            dense_s2_args=_resolve_dense_args(dense_s2_args),
            dense_s3_args=_resolve_dense_args(dense_s3_args),
            pooling_fun=pooling_fun,
        )

        # Create equivariant layers and final invariant layer
        self.equiv_layers = Sequential([EquivariantModule(settings) for _ in range(num_equiv)])
        self.inv = InvariantModule(settings)

        # Output layer to output "summary_dim" learned summary statistics
        self.out_layer = Dense(summary_dim, activation="linear")
        self.summary_dim = summary_dim

    def call(self, x, **kwargs):
        """Performs the forward pass of a learnable deep invariant transformation consisting of
        a sequence of equivariant transforms followed by an invariant transform.

        Parameters
        ----------
        x : Tensor
            Input of shape (batch_size, n_obs, data_dim)

        Returns
        -------
        out : Tensor
            Output of shape (batch_size, out_dim)
        """

        # Pass through series of augmented equivariant transforms
        out_equiv = self.equiv_layers(x, **kwargs)

        # Pass through final invariant layer, then project to the summary dimension
        out = self.out_layer(self.inv(out_equiv, **kwargs), **kwargs)

        return out
61 changes: 61 additions & 0 deletions
61
bayesflow/experimental/networks/deep_set/equivariant_module.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from .invariant_module import InvariantModule | ||
from keras import Sequential | ||
from keras.api.layers import Dense | ||
import keras | ||
|
||
class EquivariantModule(keras.Model):
    """Equivariant module that applies an equivariant transform to a set of inputs.

    For details and justification, see:

    [1] Bloem-Reddy, B., & Teh, Y. W. (2020). Probabilistic Symmetries and Invariant Neural Networks.
    J. Mach. Learn. Res., 21, 90-1. https://www.jmlr.org/papers/volume21/19-322/19-322.pdf
    """

    def __init__(self, settings, **kwargs):
        """Builds an equivariant module following [1]: a nested invariant transform whose output
        is combined with every set member, so that set members can interact.

        Parameters
        ----------
        settings : dict
            Configuration dictionary. This constructor reads the keys ``"dense_s3_args"`` and
            ``"num_dense_s3"`` and forwards the whole dictionary to ``InvariantModule``.
        **kwargs : dict, optional, default: {}
            Optional keyword arguments forwarded to the ``keras.Model`` constructor.
        """

        super().__init__(**kwargs)

        # Nested invariant transform providing the set-level embedding
        self.invariant_module = InvariantModule(settings)

        # Stack of dense layers forming the equivariant function s3
        s3_layers = []
        for _ in range(settings["num_dense_s3"]):
            s3_layers.append(Dense(**settings["dense_s3_args"]))
        self.s3 = Sequential(s3_layers)

    def call(self, x, **kwargs):
        """Runs the forward pass of the learnable equivariant transform.

        Parameters
        ----------
        x : Tensor
            Input of shape (batch_size, ..., x_dim)

        Returns
        -------
        out : Tensor
            Output of shape (batch_size, ..., equiv_dim)
        """

        # Remember the input shape, i.e. (batch_size, ..., some_dim)
        input_shape = keras.ops.shape(x)

        # Invariant embedding, e.g. (batch_size, inv_dim) -> (batch_size, 1, inv_dim)
        set_embedding = keras.ops.expand_dims(self.invariant_module(x, **kwargs), -2)

        # Repeat the embedding along the second-to-last axis only, e.g. -> (batch_size, N, inv_dim)
        multiples = [1] * (len(input_shape) - 2) + [input_shape[-2], 1]
        set_embedding_tiled = keras.ops.tile(set_embedding, multiples)

        # Attach the repeated invariant embedding to every set member
        augmented = keras.ops.concatenate([x, set_embedding_tiled], axis=-1)

        # Apply the equivariant function
        return self.s3(augmented, **kwargs)
63 changes: 63 additions & 0 deletions
63
bayesflow/experimental/networks/deep_set/invariant_module.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import keras | ||
from keras import Sequential | ||
from keras.api.layers import Dense | ||
from bayesflow.exceptions import ConfigurationError | ||
from functools import partial | ||
import tensorflow as tf | ||
|
||
class InvariantModule(keras.Model):
    """Implements an invariant module performing a permutation-invariant transform.

    For details and rationale, see:

    [1] Bloem-Reddy, B., & Teh, Y. W. (2020). Probabilistic Symmetries and Invariant Neural Networks.
    J. Mach. Learn. Res., 21, 90-1. https://www.jmlr.org/papers/volume21/19-322/19-322.pdf
    """

    def __init__(self, settings, **kwargs):
        """Creates an invariant module according to [1] which represents a learnable permutation-invariant
        function with an option for learnable pooling.

        Parameters
        ----------
        settings : dict
            A dictionary holding the configuration settings for the module. The keys read here are
            ``"dense_s1_args"``, ``"num_dense_s1"``, ``"dense_s2_args"``, ``"num_dense_s2"``
            and ``"pooling_fun"``.
        **kwargs : dict, optional, default: {}
            Optional keyword arguments passed to the ``keras.Model`` constructor.

        Raises
        ------
        ConfigurationError
            If ``settings["pooling_fun"]`` is neither ``"mean"``, ``"max"`` nor a callable.
        """

        super().__init__(**kwargs)

        # Create internal functions: s1 is applied before pooling, s2 after pooling
        self.s1 = Sequential([Dense(**settings["dense_s1_args"]) for _ in range(settings["num_dense_s1"])])
        self.s2 = Sequential([Dense(**settings["dense_s2_args"]) for _ in range(settings["num_dense_s2"])])

        # Pick pooling function. The built-in poolers go through keras.ops (rather than
        # TensorFlow-specific ops) so the module does not depend on a particular backend.
        requested = settings["pooling_fun"]
        if requested == "mean":
            pooling_fun = partial(keras.ops.mean, axis=-2)
        elif requested == "max":
            pooling_fun = partial(keras.ops.max, axis=-2)
        elif callable(requested):
            # User-supplied pooling, e.g. a learnable network
            pooling_fun = requested
        else:
            raise ConfigurationError("pooling_fun argument not understood!")
        self.pooler = pooling_fun

    def call(self, x, **kwargs):
        """Performs the forward pass of a learnable invariant transform.

        Parameters
        ----------
        x : Tensor
            Input of shape (batch_size,..., x_dim)

        Returns
        -------
        out : Tensor
            Output of shape (batch_size,..., out_dim)
        """

        # Embed each set member, pool (the built-in poolers reduce axis -2), then transform
        x_reduced = self.pooler(self.s1(x, **kwargs))
        out = self.s2(x_reduced, **kwargs)
        return out