Skip to content

Commit

Permalink
Add serialization tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Oct 13, 2023
1 parent 58b0898 commit 545c7c3
Showing 1 changed file with 94 additions and 22 deletions.
116 changes: 94 additions & 22 deletions tests/test_flow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,30 @@
import jax
import harmonic as hm

def standard_nd_gaussian_pdf(x):
"""
Calculate the probability density function (PDF) of an n-dimensional Gaussian
distribution with zero mean and unit covariance.
Parameters:
- x: Input vector of length n.
Returns:
- pdf: log PDF value at input vector x.
"""
n = len(x)

# The normalizing constant (coefficient)
C = -jnp.log(2 * jnp.pi)*n/2

# Calculate the Mahalanobis distance
mahalanobis_dist = jnp.dot(x, x)

# Calculate the PDF value
pdf = C - 0.5 * mahalanobis_dist

return pdf

def test_RealNVP_constructor():

with pytest.raises(ValueError):
Expand Down Expand Up @@ -45,7 +69,7 @@ def test_RealNVP_constructor():

assert RealNVP.is_fitted() == False
training_samples = jnp.zeros((12,ndim))
RealNVP.fit(training_samples)
RealNVP.fit(training_samples, verbose=True)
assert RealNVP.is_fitted() == True


Expand Down Expand Up @@ -170,27 +194,75 @@ def test_RQSpline_gaussian():
for i in range(ndim):
assert sample_var[i] > sample_var_concentrated[i], "Reducing temperature increases variance in dimension " + str(i)

def test_model_serialization():
# Define the number of dimensions and the mean of the Gaussian
ndim = 2
num_samples = 100
epochs_NVP = 50
epochs_spline = 5
# Initialize a PRNG key (you can use any valid key)
key = jax.random.PRNGKey(0)
mean = jnp.zeros(ndim)
cov = jnp.eye(ndim)

def standard_nd_gaussian_pdf(x):
"""
Calculate the probability density function (PDF) of an n-dimensional Gaussian
distribution with zero mean and unit covariance.
Parameters:
- x: Input vector of length n.
Returns:
- pdf: log PDF value at input vector x.
"""
n = len(x)

# The normalizing constant (coefficient)
C = -jnp.log(2 * jnp.pi)*n/2

# Calculate the Mahalanobis distance
mahalanobis_dist = jnp.dot(x, x)
# Generate random samples from the 2D Gaussian distribution
samples = jax.random.multivariate_normal(key, mean, cov, shape=(num_samples,))

# Calculate the PDF value
pdf = C - 0.5 * mahalanobis_dist

return pdf
# NVP params
n_scaled = 13
n_unscaled = 6

#Spline params
n_layers = 13
n_bins = 8
hidden_size = [64, 64]
spline_range = (-10.0, 10.0)

# Optimizer params
learning_rate = 0.01
momentum = 0.8
standardize = True
var_scale = 0.6

model_NVP = model_nf.RealNVPModel(ndim, n_scaled_layers=n_scaled, n_unscaled_layers=n_unscaled, learning_rate = learning_rate, momentum= momentum, standardize=standardize, temperature=var_scale)

model_NVP.fit(samples, epochs=epochs_NVP)

# Serialize model
model_NVP.serialize(".test.dat")

# Deserialize model
model_NVP2 = model_nf.RealNVPModel.deserialize(".test.dat")

assert model_NVP2.ndim == model_NVP.ndim
assert model_NVP2.is_fitted() == model_NVP.is_fitted()
assert model_NVP2.n_scaled_layers == model_NVP.n_scaled_layers
assert model_NVP2.n_unscaled_layers == model_NVP.n_unscaled_layers
assert model_NVP2.learning_rate == model_NVP.learning_rate
assert model_NVP2.momentum == model_NVP.momentum
assert model_NVP2.standardize == model_NVP.standardize
assert model_NVP2.temperature == model_NVP.temperature

test = jnp.array([jnp.ones(ndim)])
assert model_NVP2.predict(test) == model_NVP.predict(test), "Prediction for deserialized model is " + str(model_NVP2.predict(test)) + ", not equal to " + str(model_NVP.predict(test))

model_spline = model_nf.RQSplineFlow(ndim, n_layers = n_layers, n_bins = n_bins, hidden_size = hidden_size, spline_range = spline_range, standardize = standardize, learning_rate = learning_rate, momentum = momentum, temperature=var_scale)
model_spline.fit(samples, epochs=epochs_spline)
# Serialize model
model_spline.serialize(".test.dat")

# Deserialize model
model_spline2 = model_nf.RQSplineFlow.deserialize(".test.dat")
assert model_spline2.ndim == model_spline.ndim
assert model_spline2.is_fitted() == model_spline.is_fitted()
assert model_spline2.n_layers == model_spline.n_layers
assert model_spline2.n_bins == model_spline.n_bins
assert model_spline2.hidden_size == model_spline.hidden_size
assert model_spline2.spline_range == model_spline.spline_range
assert model_spline2.learning_rate == model_spline.learning_rate
assert model_spline2.momentum == model_spline.momentum
assert model_spline2.standardize == model_spline.standardize
assert model_spline2.temperature == model_spline.temperature

assert model_spline2.predict(test) == model_spline.predict(test)

0 comments on commit 545c7c3

Please sign in to comment.