Skip to content

Commit

Permalink
Fix MMD, re-implement metric, ensure functional with different sample…
Browse files Browse the repository at this point in the history
… sizes
  • Loading branch information
stefanradev93 committed Nov 8, 2024
1 parent 8c29cdb commit 0602071
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 62 deletions.
67 changes: 67 additions & 0 deletions bayesflow/metrics/functional/kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from keras import ops

from bayesflow.types import Tensor

# hard coded from ops.logspace(-6, 6, 11)
# to avoid pytorch errors/warnings if you want to use MPS
default_scales = ops.convert_to_tensor(
[
1.0000e-06,
1.5849e-05,
2.5119e-04,
3.9811e-03,
6.3096e-02,
1.0000e00,
1.5849e01,
2.5119e02,
3.9811e03,
6.3096e04,
1.0000e06,
]
)


def gaussian(x: Tensor, y: Tensor, scales: Tensor = default_scales) -> Tensor:
"""Computes a mixture of Gaussian radial basis functions (RBFs) between the samples of x and y.
Parameters
----------
x : Tensor of shape (num_draws_x, num_features)
Comprises `num_draws_x` Random draws from the "source" distribution `P`.
y : Tensor of shape (num_draws_y, num_features)
Comprises `num_draws_y` Random draws from the "source" distribution `Q`.
scales : Tensor, optional (default - default_scales)
List which denotes the widths of each of the Gaussians in the mixture.
Returns
-------
kernel_matrix : Tensor of shape (num_draws_x, num_draws_y)
The kernel matrix between pairs from `x ~ P` and `y ~ Q`.
"""
beta = 1.0 / (2.0 * scales[..., None])
dist = x[..., None] - ops.transpose(y)
dist = ops.transpose(ops.norm(dist, ord=2, axis=1))
s = ops.matmul(beta, ops.reshape(dist, newshape=(1, -1)))
return ops.reshape(ops.sum(ops.exp(-s), axis=0), newshape=ops.shape(dist))


def inverse_multiquadratic(x: Tensor, y: Tensor, scales: Tensor = default_scales) -> Tensor:
"""Computes a mixture of inverse multiquadratic RBFs between the samples of x and y.
Parameters
----------
x : Tensor of shape (num_draws_x, num_features)
Comprises `num_draws_x` Random draws from the "source" distribution `P`.
y : Tensor of shape (num_draws_y, num_features)
Comprises `num_draws_y` Random draws from the "source" distribution `Q`.
scales : Tensor, optional (default - default_scales)
List which denotes multiple scales for the IM-RBF kernel mixture.
Returns
-------
kernel_matrix : Tensor of shape (num_draws_x, num_draws_y)
The kernel matrix between pairs from `x ~ P` and `y ~ Q`.
"""
dist = ops.expand_dims(ops.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1), axis=-1)
sigmas = ops.expand_dims(scales, axis=0)
return ops.sum(sigmas / (dist + sigmas), axis=-1)
106 changes: 49 additions & 57 deletions bayesflow/metrics/functional/maximum_mean_discrepancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,69 +3,61 @@
from bayesflow.types import Tensor
from bayesflow.utils import issue_url

# hard coded from keras.ops.logspace(-6, 6, 11)
# to avoid pytorch errors/warnings if you want to use MPS
default_scales = keras.ops.convert_to_tensor(
[
1.0000e-06,
1.5849e-05,
2.5119e-04,
3.9811e-03,
6.3096e-02,
1.0000e00,
1.5849e01,
2.5119e02,
3.9811e03,
6.3096e04,
1.0000e06,
]
)


def gaussian_kernel(x1: Tensor, x2: Tensor, scales: Tensor = default_scales) -> Tensor:
residuals = x1[:, None] - x2[None, :]
residuals = keras.ops.reshape(residuals, keras.ops.shape(residuals)[:2] + (-1,))
norms = keras.ops.norm(residuals, ord=2, axis=2)
exponent = norms[:, :, None] / (2.0 * scales[None, None, :])
return keras.ops.mean(keras.ops.exp(-exponent), axis=2)


def maximum_mean_discrepancy(x1: Tensor, x2: Tensor, kernel: str = "gaussian", **kwargs) -> Tensor:
"""Computes the maximum mean discrepancy between samples x1 and x2.
:param x1: Tensor of shape (n, ...)
:param x2: Tensor of shape (n, ...)
:param kernel: Name of the kernel to use.
Default: 'gaussian'
:param kwargs: Additional keyword arguments to pass to the kernel function.
:return: Tensor of shape (n,)
The (x1)-sample-wise maximum mean discrepancy between samples in x1 and x2.
from .kernels import gaussian, inverse_multiquadratic


def maximum_mean_discrepancy(
x: Tensor, y: Tensor, kernel: str = "inverse_multiquadratic", unbiased: bool = False, **kwargs
) -> Tensor:
"""Computes a mixture of Gaussian radial basis functions (RBFs) between the samples of x and y.
See the original paper below for details and different estimators:
Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012).
A kernel two-sample test. The Journal of Machine Learning Research, 13(1), 723-773.
https://jmlr.csail.mit.edu/papers/v13/gretton12a.html
Parameters
----------
x : Tensor of shape (num_draws_x, num_features)
Comprises `num_draws_x` Random draws from the "source" distribution `P`.
y : Tensor of shape (num_draws_y, num_features)
Comprises `num_draws_y` Random draws from the "source" distribution `Q`.
kernel : str, optional (default - "inverse_multiquadratic")
The (mixture of) kernels to be used for the MMD computation.
unbiased : bool, optional (default - False)
Whether to use the unbiased MMD estimator. Default is False.
Returns
-------
mmd : Tensor of shape (1, )
The biased or unbiased empirical maximum mean discrepancy (MMD) estimator.
"""
if kernel != "gaussian":

if kernel == "gaussian":
kernel_fn = gaussian
elif kernel == "inverse_multiquadratic":
kernel_fn = inverse_multiquadratic
else:
raise ValueError(
"For now, we only support the Gaussian kernel. "
"For now, we only support a gaussian and an inverse_multiquadratic kernel."
f"If you need a different kernel, please open an issue at {issue_url}"
)
else:
kernel_fn = gaussian_kernel

# cannot check first (batch) dimension since it will be unknown at compile time
if keras.ops.shape(x1)[1:] != keras.ops.shape(x2)[1:]:
if keras.ops.shape(x)[1:] != keras.ops.shape(y)[1:]:
raise ValueError(
f"Expected x1 and x2 to live in the same feature space, "
f"but got {keras.ops.shape(x1)[1:]} != {keras.ops.shape(x2)[1:]}."
f"Expected x and y to live in the same feature space, "
f"but got {keras.ops.shape(x)[1:]} != {keras.ops.shape(y)[1:]}."
)

# use flattened versions
x1 = keras.ops.reshape(x1, (keras.ops.shape(x1)[0], -1))
x2 = keras.ops.reshape(x2, (keras.ops.shape(x2)[0], -1))

k1 = keras.ops.mean(kernel_fn(x1, x1, **kwargs), axis=1)
k2 = keras.ops.mean(kernel_fn(x2, x2, **kwargs), axis=1)
k3 = keras.ops.mean(kernel_fn(x1, x2, **kwargs), axis=1)
if unbiased:
m, n = keras.ops.shape(x)[0], keras.ops.shape(y)[0]
xx = (1.0 / (m * (m + 1))) * keras.ops.sum(kernel_fn(x, x, **kwargs))
yy = (1.0 / (n * (n + 1))) * keras.ops.sum(kernel_fn(y, y, **kwargs))
xy = (2.0 / (m * n)) * keras.ops.sum(kernel_fn(x, y, **kwargs))
else:
xx = keras.ops.mean(kernel_fn(x, x, **kwargs))
yy = keras.ops.mean(kernel_fn(y, y, **kwargs))
xy = keras.ops.mean(kernel_fn(x, y, **kwargs))

return k1 + k2 - 2.0 * k3
return xx + yy - 2.0 * xy
22 changes: 17 additions & 5 deletions bayesflow/metrics/maximum_mean_discrepancy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
from functools import partial
import keras


from .functional import maximum_mean_discrepancy


class MaximumMeanDiscrepancy(keras.metrics.MeanMetricWrapper):
def __init__(self, name="maximum_mean_discrepancy", dtype=None, **kwargs):
fn = partial(maximum_mean_discrepancy, **kwargs)
super().__init__(fn, name=name, dtype=dtype)
class MaximumMeanDiscrepancy(keras.Metric):
def __init__(
self,
name: str = "maximum_mean_discrepancy",
kernel: str = "inverse_multiquadratic",
unbiased: bool = False,
**kwargs,
):
super().__init__(name=name, **kwargs)
self.mmd = self.add_variable(shape=(), initializer="zeros", name="mmd")
self.mmd_fn = partial(maximum_mean_discrepancy, kernel=kernel, unbiased=unbiased)

def update_state(self, x, y):
self.mmd.assign(keras.ops.cast(self.mmd_fn(x, y), self.dtype))

def result(self):
return self.mmd

0 comments on commit 0602071

Please sign in to comment.