Skip to content

Commit

Permalink
Make flow parent class, change splines class name.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Oct 23, 2023
1 parent 97c22f8 commit a6edbfe
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 206 deletions.
2 changes: 1 addition & 1 deletion examples/normal_gamma_splines.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def run_example(
# Fit model
# =======================================================================
hm.logs.info_log("Fit model for {} epochs...".format(epochs_num))
model = model_nf.RQSplineFlow(
model = model_nf.RQSplineModel(
ndim, standardize=standardize, temperature=var_scale
)
model.fit(chains_train.samples, epochs=epochs_num)
Expand Down
2 changes: 1 addition & 1 deletion examples/radiata_pine_splines.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def run_example(
Fit model by selecing the configuration of hyper-parameters which
minimises the validation variances.
"""
model = model_nf.RQSplineFlow(ndim, standardize=standardize, temperature=var_scale)
model = model_nf.RQSplineModel(ndim, standardize=standardize, temperature=var_scale)
model.fit(chains_train.samples, epochs=epochs_num)

# ===========================================================================
Expand Down
2 changes: 1 addition & 1 deletion examples/rastrigin_splines.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def run_example(
"""
Fit model.
"""
model = model_nf.RQSplineFlow(
model = model_nf.RQSplineModel(
ndim,
n_layers=n_layers,
n_bins=n_bins,
Expand Down
2 changes: 1 addition & 1 deletion examples/rosenbrock_splines.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def run_example(
# Fit model
# =======================================================================
hm.logs.info_log("Fit model for {} epochs...".format(epochs_num))
model = model_nf.RQSplineFlow(
model = model_nf.RQSplineModel(
ndim, standardize=standardize, temperature=var_scale
)
model.fit(chains_train.samples, epochs=epochs_num)
Expand Down
Loading

0 comments on commit a6edbfe

Please sign in to comment.