From 936e6715b7ff3ba031288d72f395f3a7cf2a658a Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Wed, 6 Nov 2024 22:25:28 +0800 Subject: [PATCH] Suppress warnings for mismatched tuples and lists in functional models. --- keras/src/models/functional.py | 6 +++++- keras/src/models/functional_test.py | 9 ++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index 0acb84ed028..9c71308a651 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -214,8 +214,12 @@ def _assert_input_compatibility(self, *args): def _maybe_warn_inputs_struct_mismatch(self, inputs): try: + # We first normalize to tuples before performing the check to + # suppress warnings when encountering mismatched tuples and lists. tree.assert_same_structure( - inputs, self._inputs_struct, check_types=False + tree.lists_to_tuples(inputs), + tree.lists_to_tuples(self._inputs_struct), + check_types=False, ) except: model_inputs_struct = tree.map_structure( diff --git a/keras/src/models/functional_test.py b/keras/src/models/functional_test.py index ef792fd7f11..44858d33811 100644 --- a/keras/src/models/functional_test.py +++ b/keras/src/models/functional_test.py @@ -1,4 +1,5 @@ import os +import warnings import numpy as np import pytest @@ -503,13 +504,19 @@ def test_warning_for_mismatched_inputs_structure(self): model = Model({"i1": i1, "i2": i2}, outputs) with pytest.warns() as record: - model([np.ones((2, 2)), np.zeros((2, 2))]) + model.predict([np.ones((2, 2)), np.zeros((2, 2))], verbose=0) self.assertLen(record, 1) self.assertStartsWith( str(record[0].message), r"The structure of `inputs` doesn't match the expected structure:", ) + # No warning for mismatched tuples and lists. + model = Model([i1, i2], outputs) + with warnings.catch_warnings(record=True) as warning_logs: + model.predict((np.ones((2, 2)), np.zeros((2, 2))), verbose=0) + self.assertLen(warning_logs, 0) + def test_for_functional_in_sequential(self): # Test for a v3.4.1 regression. if backend.image_data_format() == "channels_first":