Skip to content

Commit

Permalink
add flow matching to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed May 29, 2024
1 parent dbcc0f4 commit 4a0c92d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ def __init__(self, network: keras.Layer, **kwargs):
super().__init__(**kwargs)
self.network = network

@classmethod
def new(cls, network: str = "resnet", base_distribution: str = "normal"):
# TODO: we probably want to provide a factory method like this, since the other networks use it
# for high-level input parameters
# network = find_network(network)
return cls(network, base_distribution=base_distribution)

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "FlowMatching":
# TODO: the base distribution must be savable and loadable
Expand Down
6 changes: 6 additions & 0 deletions tests/test_networks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ def coupling_flow():
return CouplingFlow.new()


@pytest.fixture()
def flow_matching():
from bayesflow.experimental.networks import FlowMatching
return FlowMatching.new()


@pytest.fixture(params=["coupling_flow"])
def inference_network(request):
return request.getfixturevalue(request.param)
Expand Down

0 comments on commit 4a0c92d

Please sign in to comment.