Skip to content

Commit

Permalink
Change training params.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Feb 1, 2024
1 parent cac7d71 commit 4e2d319
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tests/test_flow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

model_classes = [md.RealNVPModel, md.RQSplineModel]

models_to_test = [real_nvp_2D, spline_4D, spline_3D]
models_to_test = [real_nvp_2D, spline_4D]
models_to_test1 = [spline_4D, spline_3D]
gaussian_var = [0.1,0.5, 1.,10.]

# Make models for serialization tests
Expand Down Expand Up @@ -151,17 +152,17 @@ def test_flow_is_fitted(model):
assert model.is_fitted() == True


@pytest.mark.parametrize("model", models_to_test)
@pytest.mark.parametrize("model", models_to_test1)
@pytest.mark.parametrize("var", gaussian_var)
def test_flows_gaussian_pdf(model, var):
# Define the number of dimensions and the mean of the Gaussian
ndim = model.ndim
num_samples = 10000
num_samples = 20000

if isinstance(model, md.RealNVPModel):
epochs = 160
epochs = 200
elif isinstance(model, md.RQSplineModel):
epochs = 40
epochs = 60

# Initialize a PRNG key (you can use any valid key)
key = jax.random.PRNGKey(0)
Expand Down

0 comments on commit 4e2d319

Please sign in to comment.