diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py
index 3ba275162..03a27c6e6 100644
--- a/bayesflow/networks/consistency_models/consistency_model.py
+++ b/bayesflow/networks/consistency_models/consistency_model.py
@@ -7,7 +7,7 @@
 import numpy as np

 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, keras_kwargs
+from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type

 from ..inference_network import InferenceNetwork

@@ -88,6 +88,27 @@ def __init__(

         self.seed_generator = keras.random.SeedGenerator()

+        # serialization: store all parameters necessary to call __init__
+        self.config = {
+            "total_steps": total_steps,
+            "max_time": max_time,
+            "sigma2": sigma2,
+            "eps": eps,
+            "s0": s0,
+            "s1": s1,
+            **kwargs,
+        }
+        self.config = serialize_value_or_type(self.config, "subnet", subnet)
+
+    def get_config(self):
+        base_config = super().get_config()
+        return base_config | self.config
+
+    @classmethod
+    def from_config(cls, config):
+        config = deserialize_value_or_type(config, "subnet")
+        return cls(**config)
+
     def _schedule_discretization(self, step) -> float:
         """Schedule function for adjusting the discretization level `N` during
         the course of training.
diff --git a/bayesflow/networks/consistency_models/continuous_consistency_model.py b/bayesflow/networks/consistency_models/continuous_consistency_model.py
index 2dc319782..c459ab535 100644
--- a/bayesflow/networks/consistency_models/continuous_consistency_model.py
+++ b/bayesflow/networks/consistency_models/continuous_consistency_model.py
@@ -7,7 +7,16 @@
 import numpy as np

 from bayesflow.types import Tensor
-from bayesflow.utils import jvp, concatenate, find_network, keras_kwargs, expand_right_as, expand_right_to
+from bayesflow.utils import (
+    jvp,
+    concatenate,
+    find_network,
+    keras_kwargs,
+    expand_right_as,
+    expand_right_to,
+    serialize_value_or_type,
+    deserialize_value_or_type,
+)

 from ..inference_network import InferenceNetwork

@@ -62,6 +71,22 @@ def __init__(

         self.seed_generator = keras.random.SeedGenerator()

+        # serialization: store all parameters necessary to call __init__
+        self.config = {
+            "sigma_data": sigma_data,
+            **kwargs,
+        }
+        self.config = serialize_value_or_type(self.config, "subnet", subnet)
+
+    def get_config(self):
+        base_config = super().get_config()
+        return base_config | self.config
+
+    @classmethod
+    def from_config(cls, config):
+        config = deserialize_value_or_type(config, "subnet")
+        return cls(**config)
+
     def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs):
         t = np.linspace(0.0, np.pi / 2, num_steps)
         times = np.exp((t - np.pi / 2) * rho) * np.pi / 2
diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py
index c5be3bc69..a357d52d8 100644
--- a/bayesflow/networks/coupling_flow/coupling_flow.py
+++ b/bayesflow/networks/coupling_flow/coupling_flow.py
@@ -2,7 +2,7 @@
 from keras.saving import register_keras_serializable as serializable

 from bayesflow.types import Tensor
-from bayesflow.utils import find_permutation, keras_kwargs
+from bayesflow.utils import find_permutation, keras_kwargs, serialize_value_or_type, deserialize_value_or_type

 from .actnorm import ActNorm
 from .couplings import DualCoupling
@@ -58,6 +58,17 @@ def __init__(

         self.invertible_layers.append(DualCoupling(subnet, transform, **kwargs.get("coupling_kwargs", {})))

+        # serialization: store all parameters necessary to call __init__
+        self.config = {
+            "depth": depth,
+            "transform": transform,
+            "permutation": permutation,
+            "use_actnorm": use_actnorm,
+            "base_distribution": base_distribution,
+            **kwargs,
+        }
+        self.config = serialize_value_or_type(self.config, "subnet", subnet)
+
     # noinspection PyMethodOverriding
     def build(self, xz_shape, conditions_shape=None):
         super().build(xz_shape)
@@ -65,6 +76,15 @@ def build(self, xz_shape, conditions_shape=None):
         for layer in self.invertible_layers:
             layer.build(xz_shape=xz_shape, conditions_shape=conditions_shape)

+    def get_config(self):
+        base_config = super().get_config()
+        return base_config | self.config
+
+    @classmethod
+    def from_config(cls, config):
+        config = deserialize_value_or_type(config, "subnet")
+        return cls(**config)
+
     def _forward(
         self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
     ) -> Tensor | tuple[Tensor, Tensor]:
diff --git a/bayesflow/networks/coupling_flow/couplings/dual_coupling.py b/bayesflow/networks/coupling_flow/couplings/dual_coupling.py
index d66d517bc..a862ad952 100644
--- a/bayesflow/networks/coupling_flow/couplings/dual_coupling.py
+++ b/bayesflow/networks/coupling_flow/couplings/dual_coupling.py
@@ -1,7 +1,7 @@
 import keras
 from keras.saving import register_keras_serializable as serializable

-from bayesflow.utils import keras_kwargs
+from bayesflow.utils import keras_kwargs, serialize_value_or_type, deserialize_value_or_type
 from bayesflow.types import Tensor
 from .single_coupling import SingleCoupling
 from ..invertible_layer import InvertibleLayer
@@ -15,6 +15,22 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar
         self.coupling2 = SingleCoupling(subnet, transform, **kwargs)
         self.pivot = None

+        # serialization: store all parameters necessary to call __init__
+        self.config = {
+            "transform": transform,
+            **kwargs,
+        }
+        self.config = serialize_value_or_type(self.config, "subnet", subnet)
+
+    def get_config(self):
+        base_config = super().get_config()
+        return base_config | self.config
+
+    @classmethod
+    def from_config(cls, config):
+        config = deserialize_value_or_type(config, "subnet")
+        return cls(**config)
+
     # noinspection PyMethodOverriding
     def build(self, xz_shape, conditions_shape=None):
         self.pivot = xz_shape[-1] // 2
diff --git a/bayesflow/networks/coupling_flow/couplings/single_coupling.py b/bayesflow/networks/coupling_flow/couplings/single_coupling.py
index 694bbed2a..f4fba7cb1 100644
--- a/bayesflow/networks/coupling_flow/couplings/single_coupling.py
+++ b/bayesflow/networks/coupling_flow/couplings/single_coupling.py
@@ -3,7 +3,7 @@
 from keras.saving import register_keras_serializable as serializable

 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, keras_kwargs
+from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type

 from ..invertible_layer import InvertibleLayer
 from ..transforms import find_transform
@@ -26,6 +26,22 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar
         output_projector_kwargs.setdefault("kernel_initializer", "zeros")
         self.output_projector = keras.layers.Dense(units=None, **output_projector_kwargs)

+        # serialization: store all parameters necessary to call __init__
+        self.config = {
+            "transform": transform,
+            **kwargs,
+        }
+        self.config = serialize_value_or_type(self.config, "subnet", subnet)
+
+    def get_config(self):
+        base_config = super().get_config()
+        return base_config | self.config
+
+    @classmethod
+    def from_config(cls, config):
+        config = deserialize_value_or_type(config, "subnet")
+        return cls(**config)
+
     # noinspection PyMethodOverriding
     def build(self, x1_shape, x2_shape, conditions_shape=None):
         self.output_projector.units = self.transform.params_per_dim * x2_shape[-1]
diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py
index bab2c7aeb..118d3546f 100644
--- a/bayesflow/networks/flow_matching/flow_matching.py
+++ b/bayesflow/networks/flow_matching/flow_matching.py
@@ -3,7 +3,13 @@
 from keras.saving import register_keras_serializable as serializable

 from bayesflow.types import Shape, Tensor
-from bayesflow.utils import expand_right_as, keras_kwargs, optimal_transport
+from bayesflow.utils import (
+    expand_right_as,
+    keras_kwargs,
+    optimal_transport,
+    serialize_value_or_type,
+    deserialize_value_or_type,
+)
 from ..inference_network import InferenceNetwork
 from .integrators import EulerIntegrator
 from .integrators import RK2Integrator
@@ -52,10 +58,29 @@ def __init__(
             case _:
                 raise NotImplementedError(f"No support for {integrator} integration")

+        # serialization: store all parameters necessary to call __init__
+        self.config = {
+            "base_distribution": base_distribution,
+            "integrator": integrator,
+            "use_optimal_transport": use_optimal_transport,
+            "optimal_transport_kwargs": optimal_transport_kwargs,
+            **kwargs,
+        }
+        self.config = serialize_value_or_type(self.config, "subnet", subnet)
+
     def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
         super().build(xz_shape)
         self.integrator.build(xz_shape, conditions_shape)

+    def get_config(self):
+        base_config = super().get_config()
+        return base_config | self.config
+
+    @classmethod
+    def from_config(cls, config):
+        config = deserialize_value_or_type(config, "subnet")
+        return cls(**config)
+
     def _forward(
         self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
     ) -> Tensor | tuple[Tensor, Tensor]:
diff --git a/bayesflow/networks/free_form_flow/free_form_flow.py b/bayesflow/networks/free_form_flow/free_form_flow.py
index c893d7df8..23c375c1f 100644
--- a/bayesflow/networks/free_form_flow/free_form_flow.py
+++ b/bayesflow/networks/free_form_flow/free_form_flow.py
@@ -3,7 +3,16 @@
 from keras.saving import register_keras_serializable as serializable

 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, keras_kwargs, concatenate, log_jacobian_determinant, jvp, vjp
+from bayesflow.utils import (
+    find_network,
+    keras_kwargs,
+    concatenate,
+    log_jacobian_determinant,
+    jvp,
+    vjp,
+    serialize_value_or_type,
+    deserialize_value_or_type,
+)

 from ..inference_network import InferenceNetwork

@@ -63,6 +72,26 @@ def __init__(

         self.seed_generator = keras.random.SeedGenerator()

+        # serialization: store all parameters necessary to call __init__
+        self.config = {
+            "beta": beta,
+            "base_distribution": base_distribution,
+            "hutchinson_sampling": hutchinson_sampling,
+            **kwargs,
+        }
+        self.config = serialize_value_or_type(self.config, "encoder_subnet", encoder_subnet)
+        self.config = serialize_value_or_type(self.config, "decoder_subnet", decoder_subnet)
+
+    def get_config(self):
+        base_config = super().get_config()
+        return base_config | self.config
+
+    @classmethod
+    def from_config(cls, config):
+        config = deserialize_value_or_type(config, "encoder_subnet")
+        config = deserialize_value_or_type(config, "decoder_subnet")
+        return cls(**config)
+
     # noinspection PyMethodOverriding
     def build(self, xz_shape, conditions_shape=None):
         super().build(xz_shape)
diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py
index 6f51a0696..ec8f7fffb 100644
--- a/bayesflow/utils/__init__.py
+++ b/bayesflow/utils/__init__.py
@@ -18,6 +18,7 @@
     format_bytes,
     parse_bytes,
 )
+from .serialization import serialize_value_or_type, deserialize_value_or_type
 from .jacobian_trace import jacobian_trace
 from .jacobian import compute_jacobian, log_jacobian_determinant
 from .jvp import jvp
diff --git a/bayesflow/utils/serialization.py b/bayesflow/utils/serialization.py
new file mode 100644
index 000000000..3d621585c
--- /dev/null
+++ b/bayesflow/utils/serialization.py
@@ -0,0 +1,78 @@
+import keras
+
+
+PREFIX = "_bayesflow_"
+
+
+def serialize_value_or_type(config, name, obj):
+    """Serialize an object that can be either a value or a type
+    and add it to a copy of the supplied dictionary.
+
+    Parameters
+    ----------
+    config : dict
+        Dictionary to add the serialized object to. This function does not
+        modify the dictionary in place, but returns a modified copy.
+    name : str
+        Name of the obj that should be stored. Required for later deserialization.
+    obj : object or type
+        The object to serialize. If `obj` is of type `type`, we use
+        `keras.saving.get_registered_name` to obtain the registered type name.
+        If it is not a type, we try to serialize it as a Keras object.
+
+    Returns
+    -------
+    updated_config : dict
+        Updated dictionary with a new key `"_bayesflow_<name>_type"` or
+        `"_bayesflow_<name>_val"`. The prefix is used to avoid name collisions,
+        the suffix indicates how the stored value has to be deserialized.
+
+    Notes
+    -----
+    We allow strings or `type` parameters at several places to instantiate objects
+    of a given type (e.g., `subnet` in `CouplingFlow`). As `type` objects cannot
+    be serialized, we have to distinguish the two cases for serialization and
+    deserialization. This function is a helper function to standardize and
+    simplify this.
+    """
+    updated_config = config.copy()
+    if isinstance(obj, type):
+        updated_config[f"{PREFIX}{name}_type"] = keras.saving.get_registered_name(obj)
+    else:
+        updated_config[f"{PREFIX}{name}_val"] = keras.saving.serialize_keras_object(obj)
+    return updated_config
+
+
+def deserialize_value_or_type(config, name):
+    """Deserialize an object that can be either a value or a type and add
+    it to the supplied dictionary.
+
+    Parameters
+    ----------
+    config : dict
+        Dictionary containing the object to deserialize. If a type was
+        serialized, it should contain the key `"_bayesflow_<name>_type"`.
+        If an object was serialized, it should contain the key
+        `"_bayesflow_<name>_val"`. In a copy of this dictionary,
+        the item will be replaced with the key `name`.
+    name : str
+        Name of the object to deserialize.
+
+    Returns
+    -------
+    updated_config : dict
+        Updated dictionary with a new key `name`, with a value that is either
+        a type or an object.
+
+    See Also
+    --------
+    `serialize_value_or_type`
+    """
+    updated_config = config.copy()
+    if f"{PREFIX}{name}_type" in config:
+        updated_config[name] = keras.saving.get_registered_object(config[f"{PREFIX}{name}_type"])
+        del updated_config[f"{PREFIX}{name}_type"]
+    elif f"{PREFIX}{name}_val" in config:
+        updated_config[name] = keras.saving.deserialize_keras_object(config[f"{PREFIX}{name}_val"])
+        del updated_config[f"{PREFIX}{name}_val"]
+    return updated_config
diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py
index 62796f11b..d8eac634c 100644
--- a/tests/test_networks/conftest.py
+++ b/tests/test_networks/conftest.py
@@ -8,6 +8,36 @@ def deep_set():
     return DeepSet()


+# For the serialization tests, we want to test passing str and type.
+# For all other tests, this is not necessary and would double test time.
+# Therefore, below we specify two variants of each network, one without and
+# one with a subnet parameter. The latter will only be used for the relevant
+# tests. If there is a better way to set the params to a single value ("mlp")
+# for a given test, maybe this can be simplified, but I did not see one.
+@pytest.fixture(params=["str", "type"], scope="function")
+def subnet(request):
+    if request.param == "str":
+        return "mlp"
+
+    from bayesflow.networks import MLP
+
+    return MLP
+
+
+@pytest.fixture()
+def flow_matching():
+    from bayesflow.networks import FlowMatching
+
+    return FlowMatching()
+
+
+@pytest.fixture()
+def flow_matching_subnet(subnet):
+    from bayesflow.networks import FlowMatching
+
+    return FlowMatching(subnet=subnet)
+
+
 @pytest.fixture()
 def coupling_flow():
     from bayesflow.networks import CouplingFlow
@@ -16,10 +46,10 @@ def coupling_flow():
     return CouplingFlow()


 @pytest.fixture()
-def flow_matching():
-    from bayesflow.networks import FlowMatching
+def coupling_flow_subnet(subnet):
+    from bayesflow.networks import CouplingFlow

-    return FlowMatching()
+    return CouplingFlow(subnet=subnet)


 @pytest.fixture()
@@ -29,11 +59,23 @@ def free_form_flow():
     from bayesflow.networks import FreeFormFlow

     return FreeFormFlow()


+@pytest.fixture()
+def free_form_flow_subnet(subnet):
+    from bayesflow.networks import FreeFormFlow
+
+    return FreeFormFlow(encoder_subnet=subnet, decoder_subnet=subnet)
+
+
 @pytest.fixture(params=["coupling_flow", "flow_matching", "free_form_flow"], scope="function")
 def inference_network(request):
     return request.getfixturevalue(request.param)


+@pytest.fixture(params=["coupling_flow_subnet", "flow_matching_subnet", "free_form_flow_subnet"], scope="function")
+def inference_network_subnet(request):
+    return request.getfixturevalue(request.param)
+
+
 @pytest.fixture()
 def lst_net():
     from bayesflow.networks import LSTNet
diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py
index cdc4e9e09..33395cdc5 100644
--- a/tests/test_networks/test_inference_networks.py
+++ b/tests/test_networks/test_inference_networks.py
@@ -125,22 +125,22 @@ def f(x):
     assert allclose(inverse_log_density, numerical_inverse_log_density, rtol=1e-4, atol=1e-5)


-def test_serialize_deserialize(inference_network, random_samples, random_conditions):
+def test_serialize_deserialize(inference_network_subnet, subnet, random_samples, random_conditions):
     # to save, the model must be built
-    inference_network(random_samples, conditions=random_conditions)
+    inference_network_subnet(random_samples, conditions=random_conditions)

-    serialized = serialize(inference_network)
+    serialized = serialize(inference_network_subnet)
     deserialized = deserialize(serialized)
     reserialized = serialize(deserialized)

     assert serialized == reserialized


-def test_save_and_load(tmp_path, inference_network, random_samples, random_conditions):
+def test_save_and_load(tmp_path, inference_network_subnet, subnet, random_samples, random_conditions):
     # to save, the model must be built
-    inference_network(random_samples, conditions=random_conditions)
+    inference_network_subnet(random_samples, conditions=random_conditions)

-    keras.saving.save_model(inference_network, tmp_path / "model.keras")
+    keras.saving.save_model(inference_network_subnet, tmp_path / "model.keras")
     loaded = keras.saving.load_model(tmp_path / "model.keras")

-    assert_layers_equal(inference_network, loaded)
+    assert_layers_equal(inference_network_subnet, loaded)
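
For reference, below is a minimal stand-alone sketch of the pattern this patch adds to each network: stash the __init__ arguments, serialize the subnet argument with the new helper, merge everything into get_config, and restore it in from_config. The class name, the package string, and the depth parameter are illustrative only and are not part of the changeset.

import keras

from bayesflow.utils import serialize_value_or_type, deserialize_value_or_type


@keras.saving.register_keras_serializable(package="custom")
class MyNetwork(keras.layers.Layer):
    """Illustrative layer whose `subnet` argument may be a string or a class."""

    def __init__(self, subnet: str | type = "mlp", depth: int = 4, **kwargs):
        super().__init__(**kwargs)
        self.depth = depth
        # serialization: store all parameters necessary to call __init__
        self.config = {"depth": depth}
        self.config = serialize_value_or_type(self.config, "subnet", subnet)

    def get_config(self):
        # merge the stored __init__ arguments into the base layer config
        base_config = super().get_config()
        return base_config | self.config

    @classmethod
    def from_config(cls, config):
        # restore "subnet" from either the *_type or the *_val entry
        config = deserialize_value_or_type(config, "subnet")
        return cls(**config)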
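
A short illustration of the two storage paths, assuming this branch is installed: a plain value such as "mlp" is kept under a "_bayesflow_<name>_val" key via keras.saving.serialize_keras_object, while a class is stored by its registered name under "_bayesflow_<name>_type"; deserialize_value_or_type restores a plain "subnet" entry in either case. The printed dictionaries are indicative, and the exact registered name of MLP depends on how it is registered with Keras.

from bayesflow.networks import MLP
from bayesflow.utils import serialize_value_or_type, deserialize_value_or_type

# value path: strings and other plain objects go through serialize_keras_object
as_value = serialize_value_or_type({"depth": 6}, "subnet", "mlp")
print(as_value)  # e.g. {'depth': 6, '_bayesflow_subnet_val': 'mlp'}

# type path: classes are stored by their registered name
as_type = serialize_value_or_type({"depth": 6}, "subnet", MLP)
print(as_type)   # e.g. {'depth': 6, '_bayesflow_subnet_type': '<registered name of MLP>'}

# both round-trip back to a config with a plain "subnet" entry
print(deserialize_value_or_type(as_value, "subnet"))  # {'depth': 6, 'subnet': 'mlp'}
print(deserialize_value_or_type(as_type, "subnet"))   # {'depth': 6, 'subnet': <class MLP>}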