Skip to content

Commit

Permalink
Add checks for number of scaled layers to be non zero.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Oct 19, 2023
1 parent 0a78a3c commit 52fb4b2
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 3 deletions.
12 changes: 10 additions & 2 deletions harmonic/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# NVP Flow
# ===============================================================================


class RealNVP(nn.Module):
"""
Real-valued non-volume preserving flow using flax and tfp-jax.
Expand All @@ -34,8 +33,9 @@ class RealNVP(nn.Module):
n_features: int
n_scaled_layers: int = 2
n_unscaled_layers: int = 4

def setup(self):

self.scaled_layers = [AffineCoupling() for i in range(self.n_scaled_layers)]
self.unscaled_layers = [
AffineCoupling(apply_scaling=False) for i in range(self.n_unscaled_layers)
Expand All @@ -54,8 +54,16 @@ def make_flow(self, var_scale=1.0):
tfb.Distribution: Base Gaussian transformed by scaled contained in the scaled_layers
attribute, followed by unscaled affine coupling layers contained in the
unscaled_layers attribute.
Raises:
ValueError: If n_scaled_layers is not positive.
"""

if self.n_scaled_layers <= 0:
raise ValueError("Number of scaled layers must be greater than zero.")

chain = []
ix = jnp.arange(self.n_features)
permutation = [ix[-1], *ix[:-1]]
Expand Down
5 changes: 5 additions & 0 deletions harmonic/model_nf.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,15 @@ def __init__(
ValueError: If the ndim_in is not positive.
ValueError: If n_scaled_layers is not positive.
"""

if ndim_in < 1:
raise ValueError("Dimension must be greater than 0.")

if n_scaled_layers <= 0:
raise ValueError("Number of scaled layers must be greater than 0.")

self.ndim = ndim_in
self.fitted = False
Expand Down
12 changes: 11 additions & 1 deletion tests/test_flow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ def test_RealNVP_constructor():
RealNVP = model_nf.RealNVPModel(-1)

ndim = 3
RealNVP = model_nf.RealNVPModel(ndim, standardize=True)

with pytest.raises(ValueError):
RealNVP = model_nf.RealNVPModel(ndim, n_scaled_layers=0)

RealNVP = model_nf.RealNVPModel(ndim, standardize=True)

with pytest.raises(ValueError):
training_samples = jnp.zeros((12,ndim+1))
Expand Down Expand Up @@ -72,6 +76,12 @@ def test_RealNVP_constructor():
RealNVP.fit(training_samples, verbose=True, epochs=5)
assert RealNVP.is_fitted() == True

def test_RealNVP_flow():

with pytest.raises(ValueError):
flow = hm.flows.RealNVP(3, n_scaled_layers=0)
flow.make_flow()


def test_RQSpline_constructor():

Expand Down

0 comments on commit 52fb4b2

Please sign in to comment.