Skip to content

Commit

Permalink
Remove options arg from ModelCheckpoint callback for Keras V3 saving,…
Browse files Browse the repository at this point in the history
… streamline ModelCheckpoint saving flow. Parameterize associated tests. (#18545)

* Fix legacy optimizer handling in `compile_from_config()`.

* Add associated test

* Remove options arg from ModelCheckpoint callback for Keras V3 saving.

* Update callbacks_test.py with Keras V3 integrated tests.
  • Loading branch information
nkovela1 authored Oct 5, 2023
1 parent 6db9872 commit c497ffc
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 50 deletions.
51 changes: 29 additions & 22 deletions keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,18 +1546,9 @@ def _save_model(self, epoch, batch, logs):
f"saving model to {filepath}"
)
self.best = current
if self.save_weights_only:
self.model.save_weights(
filepath,
overwrite=True,
options=self._options,
)
else:
self.model.save(
filepath,
overwrite=True,
options=self._options,
)

# Handles saving and corresponding options
self._save_handler(filepath)
else:
if self.verbose > 0:
io_utils.print_msg(
Expand All @@ -1570,16 +1561,9 @@ def _save_model(self, epoch, batch, logs):
io_utils.print_msg(
f"\nEpoch {epoch + 1}: saving model to {filepath}"
)
if self.save_weights_only:
self.model.save_weights(
filepath, overwrite=True, options=self._options
)
elif filepath.endswith(".keras"):
self.model.save(filepath, overwrite=True)
else:
self.model.save(
filepath, overwrite=True, options=self._options
)

# Handles saving and corresponding options
self._save_handler(filepath)

self._maybe_remove_file()
except IsADirectoryError: # h5py 3.x
Expand All @@ -1600,6 +1584,29 @@ def _save_model(self, epoch, batch, logs):
# Re-throw the error for any other causes.
raise e

def _save_handler(self, filepath):
if self.save_weights_only:
if filepath.endswith(".weights.h5"):
self.model.save_weights(
filepath,
overwrite=True,
)
else:
self.model.save_weights(
filepath,
overwrite=True,
options=self._options,
)
else:
if filepath.endswith(".keras"):
self.model.save(filepath, overwrite=True)
else:
self.model.save(
filepath,
overwrite=True,
options=self._options,
)

def _get_file_path(self, epoch, batch, logs):
"""Returns the file path for checkpoint."""

Expand Down
74 changes: 46 additions & 28 deletions keras/callbacks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def test_callback_list_methods(self):
)


class KerasCallbacksTest(test_combinations.TestCase):
class KerasCallbacksTest(test_combinations.TestCase, parameterized.TestCase):
def _get_model(self, input_shape=None, additional_metrics=None):
additional_metrics = additional_metrics or []
layers = [
Expand Down Expand Up @@ -886,8 +886,12 @@ def generator():
self.assertGreater(float(val_loss[0]), 0.0)

@test_combinations.run_with_all_model_types
def test_ModelCheckpoint(self):
if h5py is None:
@parameterized.named_parameters(
("h5", ".h5"),
("keras", ".keras"),
)
def test_ModelCheckpoint(self, save_format):
if save_format == ".h5" and h5py is None:
return # Skip test if models cannot be saved.

model_type = test_utils.get_model_type()
Expand Down Expand Up @@ -915,7 +919,7 @@ def test_ModelCheckpoint(self):

# Save model to a subdir inside the temp_dir so we can test
# automatic directory creation.
filepath = os.path.join(temp_dir, "subdir", "checkpoint.h5")
filepath = os.path.join(temp_dir, "subdir", "checkpoint" + save_format)
(x_train, y_train), (x_test, y_test) = test_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
test_samples=TEST_SAMPLES,
Expand Down Expand Up @@ -1040,7 +1044,9 @@ def test_ModelCheckpoint(self):
period = 2
mode = "auto"

filepath = os.path.join(temp_dir, "checkpoint.{epoch:02d}.h5")
filepath = os.path.join(
temp_dir, "checkpoint.{epoch:02d}" + save_format
)
cbks = [
keras.callbacks.ModelCheckpoint(
filepath,
Expand Down Expand Up @@ -1077,7 +1083,9 @@ def test_ModelCheckpoint(self):
# Case 7: `ModelCheckpoint` with a combination of `save_freq` and
# `period`. Though `period` is deprecated, we're testing it for
# backward-compatibility.
filepath = os.path.join(temp_dir, "checkpoint.epoch{epoch:02d}.h5")
filepath = os.path.join(
temp_dir, "checkpoint.epoch{epoch:02d}" + save_format
)
cbks = [
keras.callbacks.ModelCheckpoint(
filepath,
Expand Down Expand Up @@ -1109,7 +1117,9 @@ def test_ModelCheckpoint(self):
os.remove(filepath.format(epoch=10))

# Case 8: `ModelCheckpoint` with an integer `save_freq`
filepath = os.path.join(temp_dir, "checkpoint.epoch{epoch:02d}.h5")
filepath = os.path.join(
temp_dir, "checkpoint.epoch{epoch:02d}" + save_format
)
cbks = [
keras.callbacks.ModelCheckpoint(
filepath,
Expand Down Expand Up @@ -1169,24 +1179,31 @@ def test_ModelCheckpoint(self):
)

# Case 10: `ModelCheckpoint` with valid and invalid `options` argument.
with self.assertRaisesRegex(TypeError, "tf.train.CheckpointOptions"):
keras.callbacks.ModelCheckpoint(
filepath,
monitor=monitor,
save_best_only=save_best_only,
save_weights_only=True,
mode=mode,
options=tf.saved_model.SaveOptions(),
)
with self.assertRaisesRegex(TypeError, "tf.saved_model.SaveOptions"):
keras.callbacks.ModelCheckpoint(
filepath,
monitor=monitor,
save_best_only=save_best_only,
save_weights_only=False,
mode=mode,
options=tf.train.CheckpointOptions(),
)
if save_format == ".h5":
with self.assertRaisesRegex(
TypeError, "tf.train.CheckpointOptions"
):
keras.callbacks.ModelCheckpoint(
filepath,
monitor=monitor,
save_best_only=save_best_only,
save_weights_only=True,
mode=mode,
options=tf.saved_model.SaveOptions(),
)

with self.assertRaisesRegex(
TypeError, "tf.saved_model.SaveOptions"
):
keras.callbacks.ModelCheckpoint(
filepath,
monitor=monitor,
save_best_only=save_best_only,
save_weights_only=False,
mode=mode,
options=tf.train.CheckpointOptions(),
)

keras.callbacks.ModelCheckpoint(
filepath,
monitor=monitor,
Expand All @@ -1206,7 +1223,8 @@ def test_ModelCheckpoint(self):

# Case 11: `ModelCheckpoint` save model with batch number in filename.
filepath = os.path.join(
temp_dir, "checkpoint.epoch{epoch:02d}batch{batch:02d}.h5"
temp_dir,
"checkpoint.epoch{epoch:02d}batch{batch:02d}" + save_format,
)
cbks = [
keras.callbacks.ModelCheckpoint(
Expand Down Expand Up @@ -1261,7 +1279,7 @@ def test_ModelCheckpoint(self):
monitor = "val_acc"
initial_value_threshold = 0
save_best_only = True
filepath = os.path.join(temp_dir, "checkpoint.h5")
filepath = os.path.join(temp_dir, "checkpoint" + save_format)
cbks = [
keras.callbacks.ModelCheckpoint(
filepath,
Expand Down Expand Up @@ -1389,7 +1407,7 @@ def test_ModelCheckpoint_subclass_SavedModel_save_weights_false(self):
self.assertIn("saved_model.pb", os.listdir(filepath))

@test_utils.run_v2_only
def test_ModelCheckpoint_subclass_KerasV3_save_weights_false(self):
def test_ModelCheckpoint_subclass_KerasV3(self):
model = test_utils.get_small_subclass_mlp(NUM_HIDDEN, NUM_CLASSES)
model.compile(
loss="categorical_crossentropy",
Expand Down

0 comments on commit c497ffc

Please sign in to comment.