From 0acd57524b80d08f0c89743b564ced6911236bd0 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Mon, 23 Sep 2024 16:05:29 -0700 Subject: [PATCH] Avoid redundancy in file comparison --- keras/src/saving/file_editor.py | 12 +++++++++++- keras/src/saving/file_editor_test.py | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/keras/src/saving/file_editor.py b/keras/src/saving/file_editor.py index 59ac5ba22df..afd9cc15c79 100644 --- a/keras/src/saving/file_editor.py +++ b/keras/src/saving/file_editor.py @@ -138,13 +138,18 @@ def _compare( ref_name, error_count, match_count, + checked_paths, ): base_inner_path = inner_path for ref_key, ref_val in ref_spec.items(): inner_path = base_inner_path + "/" + ref_key + if inner_path in checked_paths: + continue + if ref_key not in target: + error_count += 1 + checked_paths.add(inner_path) if isinstance(ref_val, dict): - error_count += 1 self.console.print( f"[color(160)]...Object [bold]{inner_path}[/] " f"present in {ref_name}, " @@ -169,12 +174,14 @@ def _compare( ref_name, error_count=error_count, match_count=match_count, + checked_paths=checked_paths, ) error_count += _error_count match_count += _match_count else: if target[ref_key].shape != ref_val.shape: error_count += 1 + checked_paths.add(inner_path) self.console.print( f"[color(160)]...Weight shape mismatch " f"for [bold]{inner_path}[/][/]\n" @@ -187,6 +194,7 @@ def _compare( match_count += 1 return error_count, match_count + checked_paths = set() error_count, match_count = _compare( self.weights_dict, ref_spec, @@ -195,6 +203,7 @@ def _compare( ref_name="reference model", error_count=0, match_count=0, + checked_paths=checked_paths, ) _error_count, _ = _compare( ref_spec, @@ -204,6 +213,7 @@ def _compare( ref_name="saved file", error_count=0, match_count=0, + checked_paths=checked_paths, ) error_count += _error_count self.console.print("─────────────────────") diff --git a/keras/src/saving/file_editor_test.py b/keras/src/saving/file_editor_test.py index 1e9593b9725..1b24813877a 100644 --- a/keras/src/saving/file_editor_test.py +++ b/keras/src/saving/file_editor_test.py @@ -74,3 +74,18 @@ def test_basics(self): editor.summary() out = editor.compare_to(target_model) # Succeeds self.assertEqual(out["status"], "success") + + editor.delete_weight("dense_2", "1") + out = editor.compare_to(target_model) # Fails + self.assertEqual(out["status"], "error") + self.assertEqual(out["error_count"], 1) + + editor.add_weights("dense_2", {"1": np.zeros((7,))}) + out = editor.compare_to(target_model) # Fails + self.assertEqual(out["status"], "error") + self.assertEqual(out["error_count"], 1) + + editor.delete_weight("dense_2", "1") + editor.add_weights("dense_2", {"1": np.zeros((3,))}) + out = editor.compare_to(target_model) # Succeeds + self.assertEqual(out["status"], "success")