Skip to content

Commit

Permalink
Prepare test_clone for cnn model
Browse files Browse the repository at this point in the history
(using data_format metadata)
  • Loading branch information
nhuet committed Mar 19, 2024
1 parent f3b9754 commit 9c5b341
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,7 +1293,9 @@ def toy_struct_v2(
return Model(x, y)

@staticmethod
def toy_struct_cnn(input_shape: tuple[int, ...] = (6, 6, 2), dtype: Optional[str] = None):
def toy_struct_cnn(
input_shape: tuple[int, ...] = (6, 6, 2), dtype: Optional[str] = None, data_format="channels_last"
):
if dtype is None:
dtype = keras_config.floatx()
layers = [
Expand All @@ -1302,7 +1304,7 @@ def toy_struct_cnn(input_shape: tuple[int, ...] = (6, 6, 2), dtype: Optional[str
10,
kernel_size=(3, 3),
activation="relu",
data_format="channels_last",
data_format=data_format,
dtype=dtype,
),
Flatten(dtype=dtype),
Expand Down
6 changes: 5 additions & 1 deletion tests/test_clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ def test_clone(
decimal = 4

# keras model to convert
keras_model = toy_model_fn(input_shape=input_shape)
if toy_model_name == "cnn":
kwargs_toy_model = dict(data_format=model_decomon_input_metadata["data_format"])
else:
kwargs_toy_model = {}
keras_model = toy_model_fn(input_shape=input_shape, **kwargs_toy_model)

# conversion
decomon_model = clone(model=keras_model, slope=slope, perturbation_domain=perturbation_domain, method=method)
Expand Down

0 comments on commit 9c5b341

Please sign in to comment.