diff --git a/tests/test_amortizers/__init__.py b/tests/test_amortizers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_amortizers/conftest.py b/tests/test_amortizers/conftest.py new file mode 100644 index 000000000..09b0e0d4e --- /dev/null +++ b/tests/test_amortizers/conftest.py @@ -0,0 +1,47 @@ + +import keras +import pytest + +import bayesflow.experimental as bf + + +@pytest.fixture() +def summary_network(): + return None + + +@pytest.fixture() +def inference_network(): + network = keras.Sequential([ + keras.layers.Dense(10) + ]) + network.compile(loss="mse") + return network + + +@pytest.fixture(params=[bf.AmortizedPosterior, bf.AmortizedLikelihood]) +def amortizer(request, inference_network, summary_network): + Amortizer = request.param + return Amortizer(inference_network, summary_network) + + +@pytest.fixture() +def dataset(): + batch_size = 16 + batches_per_epoch = 4 + parameter_sets = batch_size * batches_per_epoch + observations_per_parameter_set = 32 + + mean = keras.random.normal(mean=0.0, stddev=0.1, shape=(parameter_sets, 2)) + std = keras.ops.exp(keras.random.normal(mean=0.0, stddev=0.1, shape=(parameter_sets, 2))) + + mean = keras.ops.repeat(mean[:, None], observations_per_parameter_set, 1) + std = keras.ops.repeat(std[:, None], observations_per_parameter_set, 1) + + noise = keras.random.normal(shape=(parameter_sets, observations_per_parameter_set, 2)) + + x = mean + std * noise + + data = dict(observables=dict(x=x), parameters=dict(mean=mean, std=std)) + + return bf.datasets.OfflineDataset(data, batch_size=batch_size, batches_per_epoch=batches_per_epoch) diff --git a/tests/test_amortizers/test_fit.py b/tests/test_amortizers/test_fit.py new file mode 100644 index 000000000..7b9e84246 --- /dev/null +++ b/tests/test_amortizers/test_fit.py @@ -0,0 +1,11 @@ +def test_compile(amortizer): + amortizer.compile(optimizer="AdamW") + + +def test_fit(amortizer, dataset): + amortizer.compile(optimizer="AdamW") + amortizer.fit(dataset) + + assert amortizer.losses is not None + +