Skip to content

Commit

Permalink
Add more tests for flows.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Oct 12, 2023
1 parent 2ed2cda commit 1579a31
Showing 1 changed file with 29 additions and 16 deletions.
45 changes: 29 additions & 16 deletions tests/test_flow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import jax.numpy as jnp
import jax


def test_RealNVP_constructor():

with pytest.raises(ValueError):
Expand Down Expand Up @@ -96,54 +95,68 @@ def test_RQSpline_constructor():



def test_RealNVP_temperature():
def test_RealNVP_gaussian():

# Define the number of dimensions and the mean of the Gaussian
ndim = 2
num_samples = 100
num_samples = 10000
epochs = 50
# Initialize a PRNG key (you can use any valid key)
key = jax.random.PRNGKey(0)

# Generate random samples from the 2D Gaussian distribution
samples = jax.random.normal(key, shape=(num_samples, ndim))

RealNVP = model_nf.RealNVPModel(ndim)
RealNVP.fit(samples)
RealNVP = model_nf.RealNVPModel(ndim, standardize=True)
RealNVP.fit(samples, epochs=epochs)

nsamples = 100
nsamples = 10000
RealNVP.temperature = 1.
flow_samples = RealNVP.sample(nsamples)
sample_var = jnp.var(flow_samples, axis = 0)
RealNVP.temperature = 0.8
flow_samples = RealNVP.sample(nsamples)
sample_var_concentrated = jnp.var(flow_samples, axis = 0)
flow_samples_concentrated = RealNVP.sample(nsamples)
sample_var_concentrated = jnp.var(flow_samples_concentrated, axis = 0)

for i in range(ndim):
assert sample_var[i] > sample_var_concentrated[i], "Reducing temperature increases variance in dimension " + str(i)

sample_mean = jnp.mean(flow_samples, axis=0)

for i in range(ndim):
assert sample_mean[i] == pytest.approx(0.0, abs = 0.1), "Sample mean in dimension " + str(i) + " is " + str(sample_mean[i])
assert sample_var[i] == pytest.approx(1.0, abs = 0.1), "Sample variance in dimension " + str(i) + " is " + str(sample_var[i])



def test_RQSpline_temperature():
def test_RQSpline_gaussian():

# Define the number of dimensions and the mean of the Gaussian
ndim = 2
num_samples = 100
num_samples = 10000
epochs = 20
# Initialize a PRNG key (you can use any valid key)
key = jax.random.PRNGKey(0)

# Generate random samples from the 2D Gaussian distribution
samples = jax.random.normal(key, shape=(num_samples, ndim))

spline = model_nf.RQSplineFlow(ndim)
spline.fit(samples)
spline = model_nf.RQSplineFlow(ndim, standardize=True)
spline.fit(samples, epochs=epochs)

nsamples = 100
nsamples = 10000
spline.temperature = 1.
flow_samples = spline.sample(nsamples)
sample_var = jnp.var(flow_samples, axis = 0)
spline.temperature = 0.8
flow_samples = spline.sample(nsamples)
sample_var_concentrated = jnp.var(flow_samples, axis = 0)
flow_samples_concentrated = spline.sample(nsamples)
sample_var_concentrated = jnp.var(flow_samples_concentrated, axis = 0)

for i in range(ndim):
assert sample_var[i] > sample_var_concentrated[i], "Reducing temperature increases variance in dimension " + str(i)

sample_mean = jnp.mean(flow_samples, axis=0)

for i in range(ndim):
assert sample_var[i] > sample_var_concentrated[i], "Reducing temperature increases variance in dimension " + str(i)
assert sample_mean[i] == pytest.approx(0.0, abs = 0.1), "Sample mean in dimension " + str(i) + " is " + str(sample_mean[i])
assert sample_var[i] == pytest.approx(1.0, abs = 0.1), "Sample variance in dimension " + str(i) + " is " + str(sample_var[i])

0 comments on commit 1579a31

Please sign in to comment.