diff --git a/src/decomon/layers/crown.py b/src/decomon/layers/crown.py index 0210b077..142023bb 100644 --- a/src/decomon/layers/crown.py +++ b/src/decomon/layers/crown.py @@ -38,6 +38,15 @@ def __init__( super().__init__(**kwargs) self.model_output_shape = model_output_shape + def get_config(self) -> dict[str, Any]: + config = super().get_config() + config.update( + { + "model_output_shape": self.model_output_shape, + } + ) + return config + def call(self, inputs: list[list[BackendTensor]]) -> list[BackendTensor]: """Reduce the list of crown bounds to a single one by summation. diff --git a/src/decomon/layers/fuse.py b/src/decomon/layers/fuse.py index 6e5f8351..f9447f09 100644 --- a/src/decomon/layers/fuse.py +++ b/src/decomon/layers/fuse.py @@ -99,6 +99,21 @@ def __init__( for m2_input_shape in m_1_output_shapes ] + def get_config(self) -> dict[str, Any]: + config = super().get_config() + config.update( + { + "ibp_1": self.ibp_1, + "affine_1": self.affine_1, + "ibp_2": self.ibp_2, + "affine_2": self.affine_2, + "m1_input_shape": self.m1_input_shape, + "m_1_output_shapes": self.m_1_output_shapes, + "from_linear_2": self.from_linear_2, + } + ) + return config + def build(self, input_shape: tuple[list[tuple[Optional[int], ...]], list[tuple[Optional[int], ...]]]) -> None: input_shape_1, input_shape_2 = input_shape diff --git a/src/decomon/layers/input.py b/src/decomon/layers/input.py index 560b0eae..607f8fc2 100644 --- a/src/decomon/layers/input.py +++ b/src/decomon/layers/input.py @@ -6,6 +6,7 @@ import keras import keras.ops as K from keras.layers import Layer +from keras.utils import serialize_keras_object from decomon.constants import Propagation from decomon.layers.inputs_outputs_specs import InputsOutputsSpec @@ -44,6 +45,17 @@ def __init__( layer_input_shape=tuple(), ) + def get_config(self) -> dict[str, Any]: + config = super().get_config() + config.update( + { + "ibp": self.ibp, + "affine": self.affine, + "perturbation_domain": serialize_keras_object(self.perturbation_domain), + } + ) + return config + def call(self, inputs: BackendTensor) -> list[BackendTensor]: """Generate ibp and affine bounds to propagate by the first forward layer. @@ -124,6 +136,15 @@ def __init__( self.perturbation_domain = perturbation_domain + def get_config(self) -> dict[str, Any]: + config = super().get_config() + config.update( + { + "perturbation_domain": serialize_keras_object(self.perturbation_domain), + } + ) + return config + def call(self, inputs: BackendTensor) -> list[BackendTensor]: """Generate ibp and affine bounds to propagate by the first forward layer. @@ -197,6 +218,16 @@ def __init__( self.nb_model_outputs = len(model_output_shapes) self.from_linear = from_linear + def get_config(self) -> dict[str, Any]: + config = super().get_config() + config.update( + { + "model_output_shapes": self.model_output_shapes, + "from_linear": self.from_linear, + } + ) + return config + def build(self, input_shape: list[tuple[Optional[int], ...]]) -> None: # list of tensors if len(input_shape) == 1: diff --git a/src/decomon/layers/layer.py b/src/decomon/layers/layer.py index 9a362815..a37a504b 100644 --- a/src/decomon/layers/layer.py +++ b/src/decomon/layers/layer.py @@ -4,6 +4,7 @@ import keras import keras.ops as K from keras.layers import Layer, Wrapper +from keras.utils import serialize_keras_object from decomon.constants import Propagation from decomon.layers.fuse import ( @@ -184,9 +185,10 @@ def get_config(self) -> dict[str, Any]: { "ibp": self.ibp, "affine": self.affine, - "perturbation_domain": self.perturbation_domain, + "perturbation_domain": serialize_keras_object(self.perturbation_domain), "propagation": self.propagation, - "model_output_shape_length": self.model_output_shape_length, + "model_input_shape": self.model_input_shape, + "model_output_shape": self.model_output_shape, } ) return config diff --git a/src/decomon/layers/oracle.py b/src/decomon/layers/oracle.py index 54d4e0e4..773048ca 100644 --- a/src/decomon/layers/oracle.py +++ b/src/decomon/layers/oracle.py @@ -4,6 +4,7 @@ from typing import Any, Optional, Union, overload from keras.layers import Layer +from keras.utils import serialize_keras_object from decomon.layers.inputs_outputs_specs import InputsOutputsSpec from decomon.perturbation_domain import PerturbationDomain @@ -68,6 +69,19 @@ def __init__( is_merging_layer=is_merging_layer, ) + def get_config(self) -> dict[str, Any]: + config = super().get_config() + config.update( + { + "ibp": self.ibp, + "affine": self.affine, + "perturbation_domain": serialize_keras_object(self.perturbation_domain), + "layer_input_shape": self.layer_input_shape, + "is_merging_layer": self.is_merging_layer, + } + ) + return config + def call(self, inputs: list[BackendTensor]) -> Union[list[BackendTensor], list[list[BackendTensor]]]: """Deduce ibp and affine bounds to propagate by the first forward layer. diff --git a/src/decomon/layers/output.py b/src/decomon/layers/output.py index c6d622ec..398f5f07 100644 --- a/src/decomon/layers/output.py +++ b/src/decomon/layers/output.py @@ -5,6 +5,7 @@ import keras.ops as K from keras.layers import Layer +from keras.utils import serialize_keras_object from decomon.layers.inputs_outputs_specs import InputsOutputsSpec from decomon.layers.oracle import get_forward_oracle @@ -53,6 +54,20 @@ def __init__( model_output_shapes=model_output_shapes, ) + def get_config(self) -> dict[str, Any]: + config = super().get_config() + config.update( + { + "ibp_from": self.ibp_from, + "affine_from": self.affine_from, + "ibp_to": self.ibp_to, + "affine_to": self.affine_to, + "model_output_shapes": self.model_output_shapes, + "perturbation_domain": serialize_keras_object(self.perturbation_domain), + } + ) + return config + def needs_perturbation_domain_inputs(self) -> bool: return self.inputs_outputs_spec.needs_perturbation_domain_inputs() diff --git a/src/decomon/models/models.py b/src/decomon/models/models.py index 6fff46ee..1db6e729 100644 --- a/src/decomon/models/models.py +++ b/src/decomon/models/models.py @@ -79,13 +79,10 @@ def get_config(self) -> dict[str, Any]: dict( name=self.name, perturbation_domain=serialize_keras_object(self.perturbation_domain), - dc_decomp=self.dc_decomp, method=self.method, ibp=self.ibp, affine=self.affine, - finetune=self.finetune, - shared=self.shared, - backward_bounds=self.backward_bounds, + model=serialize_keras_object(self.model), ) ) return config