Skip to content

Commit

Permalink
Increase test coverage in constraints_test and `backup_and_restore_…
Browse files Browse the repository at this point in the history
…callback_test` (#18639)

* Add more tests to `constraints_test`

* Add tests to `backup_and_restore_callback_test`
  • Loading branch information
Faisal-Alsrheed authored Oct 18, 2023
1 parent 9e30f7f commit 6e54f7f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
12 changes: 12 additions & 0 deletions keras/callbacks/backup_and_restore_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,15 @@ def test_model_deleted_case_epoch(self):
verbose=0,
)
self.assertFalse(file_utils.exists(backup_dir))

def test_backup_dir_empty_error(self):
with self.assertRaisesRegex(
ValueError, expected_regex="Empty `backup_dir` argument passed"
):
callbacks.BackupAndRestore(backup_dir="", save_freq="epoch")

def test_backup_dir_none_error(self):
with self.assertRaisesRegex(
ValueError, expected_regex="Empty `backup_dir` argument passed"
):
callbacks.BackupAndRestore(backup_dir=None, save_freq="epoch")
42 changes: 42 additions & 0 deletions keras/constraints/constraints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,45 @@ def test_get_method(self):

with self.assertRaises(ValueError):
constraints.get("typo")

def test_default_constraint_call(self):
constraint_fn = constraints.Constraint()
x = np.array([1.0, 2.0, 3.0])
output = constraint_fn(x)
self.assertAllClose(x, output)

def test_constraint_get_config(self):
constraint_fn = constraints.Constraint()
config = constraint_fn.get_config()
self.assertEqual(config, {})

def test_constraint_from_config(self):
constraint_fn = constraints.Constraint()
config = constraint_fn.get_config()
recreated_constraint_fn = constraints.Constraint.from_config(config)
self.assertIsInstance(recreated_constraint_fn, constraints.Constraint)

def test_max_norm_get_config(self):
constraint_fn = constraints.MaxNorm(max_value=3.0, axis=1)
config = constraint_fn.get_config()
expected_config = {"max_value": 3.0, "axis": 1}
self.assertEqual(config, expected_config)

def test_unit_norm_get_config(self):
constraint_fn = constraints.UnitNorm(axis=1)
config = constraint_fn.get_config()
expected_config = {"axis": 1}
self.assertEqual(config, expected_config)

def test_min_max_norm_get_config(self):
constraint_fn = constraints.MinMaxNorm(
min_value=0.5, max_value=2.0, rate=0.7, axis=1
)
config = constraint_fn.get_config()
expected_config = {
"min_value": 0.5,
"max_value": 2.0,
"rate": 0.7,
"axis": 1,
}
self.assertEqual(config, expected_config)

0 comments on commit 6e54f7f

Please sign in to comment.