From 9c5b341efb50f708078329fff93629f1188483b8 Mon Sep 17 00:00:00 2001 From: Nolwen Date: Tue, 19 Mar 2024 16:18:09 +0100 Subject: [PATCH] Prepare test_clone for cnn model (using data_format metadata) --- tests/conftest.py | 6 ++++-- tests/test_clone.py | 6 +++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index eae8aa17..99b82805 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1293,7 +1293,9 @@ def toy_struct_v2( return Model(x, y) @staticmethod - def toy_struct_cnn(input_shape: tuple[int, ...] = (6, 6, 2), dtype: Optional[str] = None): + def toy_struct_cnn( + input_shape: tuple[int, ...] = (6, 6, 2), dtype: Optional[str] = None, data_format="channels_last" + ): if dtype is None: dtype = keras_config.floatx() layers = [ @@ -1302,7 +1304,7 @@ def toy_struct_cnn(input_shape: tuple[int, ...] = (6, 6, 2), dtype: Optional[str 10, kernel_size=(3, 3), activation="relu", - data_format="channels_last", + data_format=data_format, dtype=dtype, ), Flatten(dtype=dtype), diff --git a/tests/test_clone.py b/tests/test_clone.py index 55943990..b46415c5 100644 --- a/tests/test_clone.py +++ b/tests/test_clone.py @@ -61,7 +61,11 @@ def test_clone( decimal = 4 # keras model to convert - keras_model = toy_model_fn(input_shape=input_shape) + if toy_model_name == "cnn": + kwargs_toy_model = dict(data_format=model_decomon_input_metadata["data_format"]) + else: + kwargs_toy_model = {} + keras_model = toy_model_fn(input_shape=input_shape, **kwargs_toy_model) # conversion decomon_model = clone(model=keras_model, slope=slope, perturbation_domain=perturbation_domain, method=method)