Skip to content

Commit

Permalink
Update and add get_config() methods to decomon objects
Browse files Browse the repository at this point in the history
  • Loading branch information
nhuet committed Mar 19, 2024
1 parent c883dc9 commit 7e2c25b
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 6 deletions.
9 changes: 9 additions & 0 deletions src/decomon/layers/crown.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ def __init__(
super().__init__(**kwargs)
self.model_output_shape = model_output_shape

def get_config(self) -> dict[str, Any]:
config = super().get_config()
config.update(
{
"model_output_shape": self.model_output_shape,
}
)
return config

def call(self, inputs: list[list[BackendTensor]]) -> list[BackendTensor]:
"""Reduce the list of crown bounds to a single one by summation.
Expand Down
15 changes: 15 additions & 0 deletions src/decomon/layers/fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,21 @@ def __init__(
for m2_input_shape in m_1_output_shapes
]

def get_config(self) -> dict[str, Any]:
config = super().get_config()
config.update(
{
"ibp_1": self.ibp_1,
"affine_1": self.affine_1,
"ibp_2": self.ibp_2,
"affine_2": self.affine_2,
"m1_input_shape": self.m1_input_shape,
"m_1_output_shapes": self.m_1_output_shapes,
"from_linear_2": self.from_linear_2,
}
)
return config

def build(self, input_shape: tuple[list[tuple[Optional[int], ...]], list[tuple[Optional[int], ...]]]) -> None:
input_shape_1, input_shape_2 = input_shape

Expand Down
31 changes: 31 additions & 0 deletions src/decomon/layers/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import keras
import keras.ops as K
from keras.layers import Layer
from keras.utils import serialize_keras_object

from decomon.constants import Propagation
from decomon.layers.inputs_outputs_specs import InputsOutputsSpec
Expand Down Expand Up @@ -44,6 +45,17 @@ def __init__(
layer_input_shape=tuple(),
)

def get_config(self) -> dict[str, Any]:
config = super().get_config()
config.update(
{
"ibp": self.ibp,
"affine": self.affine,
"perturbation_domain": serialize_keras_object(self.perturbation_domain),
}
)
return config

def call(self, inputs: BackendTensor) -> list[BackendTensor]:
"""Generate ibp and affine bounds to propagate by the first forward layer.
Expand Down Expand Up @@ -124,6 +136,15 @@ def __init__(

self.perturbation_domain = perturbation_domain

def get_config(self) -> dict[str, Any]:
config = super().get_config()
config.update(
{
"perturbation_domain": serialize_keras_object(self.perturbation_domain),
}
)
return config

def call(self, inputs: BackendTensor) -> list[BackendTensor]:
"""Generate ibp and affine bounds to propagate by the first forward layer.
Expand Down Expand Up @@ -197,6 +218,16 @@ def __init__(
self.nb_model_outputs = len(model_output_shapes)
self.from_linear = from_linear

def get_config(self) -> dict[str, Any]:
config = super().get_config()
config.update(
{
"model_output_shapes": self.model_output_shapes,
"from_linear": self.from_linear,
}
)
return config

def build(self, input_shape: list[tuple[Optional[int], ...]]) -> None:
# list of tensors
if len(input_shape) == 1:
Expand Down
6 changes: 4 additions & 2 deletions src/decomon/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import keras
import keras.ops as K
from keras.layers import Layer, Wrapper
from keras.utils import serialize_keras_object

from decomon.constants import Propagation
from decomon.layers.fuse import (
Expand Down Expand Up @@ -184,9 +185,10 @@ def get_config(self) -> dict[str, Any]:
{
"ibp": self.ibp,
"affine": self.affine,
"perturbation_domain": self.perturbation_domain,
"perturbation_domain": serialize_keras_object(self.perturbation_domain),
"propagation": self.propagation,
"model_output_shape_length": self.model_output_shape_length,
"model_input_shape": self.model_input_shape,
"model_output_shape": self.model_output_shape,
}
)
return config
Expand Down
14 changes: 14 additions & 0 deletions src/decomon/layers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Optional, Union, overload

from keras.layers import Layer
from keras.utils import serialize_keras_object

from decomon.layers.inputs_outputs_specs import InputsOutputsSpec
from decomon.perturbation_domain import PerturbationDomain
Expand Down Expand Up @@ -68,6 +69,19 @@ def __init__(
is_merging_layer=is_merging_layer,
)

def get_config(self) -> dict[str, Any]:
config = super().get_config()
config.update(
{
"ibp": self.ibp,
"affine": self.affine,
"perturbation_domain": serialize_keras_object(self.perturbation_domain),
"layer_input_shape": self.layer_input_shape,
"is_merging_layer": self.is_merging_layer,
}
)
return config

def call(self, inputs: list[BackendTensor]) -> Union[list[BackendTensor], list[list[BackendTensor]]]:
"""Deduce ibp and affine bounds to propagate by the first forward layer.
Expand Down
15 changes: 15 additions & 0 deletions src/decomon/layers/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import keras.ops as K
from keras.layers import Layer
from keras.utils import serialize_keras_object

from decomon.layers.inputs_outputs_specs import InputsOutputsSpec
from decomon.layers.oracle import get_forward_oracle
Expand Down Expand Up @@ -53,6 +54,20 @@ def __init__(
model_output_shapes=model_output_shapes,
)

def get_config(self) -> dict[str, Any]:
config = super().get_config()
config.update(
{
"ibp_from": self.ibp_from,
"affine_from": self.affine_from,
"ibp_to": self.ibp_to,
"affine_to": self.affine_to,
"model_output_shapes": self.model_output_shapes,
"perturbation_domain": serialize_keras_object(self.perturbation_domain),
}
)
return config

def needs_perturbation_domain_inputs(self) -> bool:
return self.inputs_outputs_spec.needs_perturbation_domain_inputs()

Expand Down
5 changes: 1 addition & 4 deletions src/decomon/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,10 @@ def get_config(self) -> dict[str, Any]:
dict(
name=self.name,
perturbation_domain=serialize_keras_object(self.perturbation_domain),
dc_decomp=self.dc_decomp,
method=self.method,
ibp=self.ibp,
affine=self.affine,
finetune=self.finetune,
shared=self.shared,
backward_bounds=self.backward_bounds,
model=serialize_keras_object(self.model),
)
)
return config
Expand Down

0 comments on commit 7e2c25b

Please sign in to comment.