Skip to content

Commit

Permalink
Merge pull request #282 from astro-informatics/normal_test_var
Browse files Browse the repository at this point in the history
Tests checking normalization of flow models
  • Loading branch information
alicjapolanska authored Feb 19, 2024
2 parents 5fa73be + 4ff4e8d commit 6f674a7
Showing 1 changed file with 63 additions and 10 deletions.
73 changes: 63 additions & 10 deletions tests/test_flow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
import harmonic as hm

real_nvp_2D = md.RealNVPModel(2, standardize=True)
spline_4D = md.RQSplineModel(4, standardize=True)
spline_4D = md.RQSplineModel(4, n_layers=2, n_bins=64, standardize=True)
spline_3D = md.RQSplineModel(3, n_layers=2, n_bins=64, standardize=False)

model_classes = [md.RealNVPModel, md.RQSplineModel]

models_to_test = [real_nvp_2D, spline_4D]
models_to_test1 = [real_nvp_2D, spline_4D, spline_3D]
gaussian_var = [0.1, 0.5, 1.0, 10.0, 20.0]

# Make models for serialization tests
# NVP params
Expand Down Expand Up @@ -54,24 +57,25 @@
models_serialization = [real_NVP_serialization, spline_serialization]


def standard_nd_gaussian_pdf(x):
def standard_nd_gaussian_pdf(x, var=1.0):
"""
Calculate the probability density function (PDF) of an n-dimensional Gaussian
distribution with zero mean and unit covariance.
distribution with zero mean and diagonal covariance with entries var.
Parameters:
- x: Input vector of length n.
- var: Gaussian variance
Returns:
- pdf: log PDF value at input vector x.
"""
n = len(x)

# The normalizing constant (coefficient)
C = -jnp.log(2 * jnp.pi) * n / 2
C = -jnp.log(2 * jnp.pi * var) * n / 2

# Calculate the Mahalanobis distance
mahalanobis_dist = jnp.dot(x, x)
mahalanobis_dist = jnp.dot(x, x) / var

# Calculate the PDF value
pdf = C - 0.5 * mahalanobis_dist
Expand Down Expand Up @@ -148,6 +152,60 @@ def test_flow_is_fitted(model):
assert model.is_fitted() == True


@pytest.mark.parametrize("model", models_to_test1)
@pytest.mark.parametrize("var", gaussian_var)
def test_flows_normalization(model, var):
# Define the number of dimensions and the mean of the Gaussian
ndim = model.ndim
num_samples = 10000

if isinstance(model, md.RealNVPModel):
epochs = 100
elif isinstance(model, md.RQSplineModel):
epochs = 30

# Initialize a PRNG key (you can use any valid key)
key = jax.random.PRNGKey(0)
mean = jnp.zeros(ndim)
cov = jnp.eye(ndim) * var

# Generate random samples from the Gaussian distribution
samples = jax.random.multivariate_normal(key, mean, cov, shape=(num_samples,))

model.fit(samples, epochs=epochs, verbose=True)
model.temperature = 1.0

# MC integral of the flow
num_samples_int = 50000
shape = (num_samples_int, ndim)
# Draw samples from uniform distribution -3 to 3 standard deviations away from mean
minval = -3 * var**0.5
maxval = 3 * var**0.5
uniform_samples = jax.random.uniform(
jax.random.PRNGKey(0), shape=shape, minval=minval, maxval=maxval
)
V = (maxval - minval) ** ndim
vals = jnp.exp(model.predict(uniform_samples))
integral = jnp.mean(vals) * V
assert integral == pytest.approx(1.0, rel=0.1), (
"Flow normalization constant is " + str(integral) + "not 1"
)

model.temperature = 0.5
vals = jnp.exp(model.predict(uniform_samples))
integral = jnp.mean(vals) * V
assert integral == pytest.approx(1.0, rel=0.1), (
"Flow with T=0.5 normalization constant is " + str(integral) + " not 1"
)

model.temperature = 0.1
vals = jnp.exp(model.predict(uniform_samples))
integral = jnp.mean(vals) * V
assert integral == pytest.approx(1.0, rel=0.1), (
"Flow with T=0.1 normalization constant is " + str(integral) + "not 1"
)


@pytest.mark.parametrize("model", models_to_test)
def test_flows_gaussian(model):
# Define the number of dimensions and the mean of the Gaussian
Expand Down Expand Up @@ -175,11 +233,6 @@ def test_flows_gaussian(model):
sample_var = jnp.var(flow_samples, axis=0)
sample_mean = jnp.mean(flow_samples, axis=0)

test = jnp.ones(ndim) * 0.2
assert jnp.exp(model.predict(test)) == pytest.approx(
jnp.exp(standard_nd_gaussian_pdf(test)), rel=0.1
), "Flow probability density not in agreement with analytical value"

for i in range(ndim):
assert sample_mean[i] == pytest.approx(0.0, abs=0.15), (
"Sample mean in dimension " + str(i) + " is " + str(sample_mean[i])
Expand Down

0 comments on commit 6f674a7

Please sign in to comment.