From cb5ee5599b58e6786cdf4eac40dca7bda44dd077 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 16 Oct 2023 13:41:52 +0300 Subject: [PATCH] Increase test coverage in `common/dtypes_test.py` (#18618) * Add tests in `common/dtypes_test.py` * Add tests in `common/dtypes_test.py` * Add more tests to `/common/dtypes_test.py` * Add tests in `common/dtypes_test.py` --- keras/backend/common/dtypes_test.py | 61 +++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/keras/backend/common/dtypes_test.py b/keras/backend/common/dtypes_test.py index f3b791ccde8..8a3eeb4fd6a 100644 --- a/keras/backend/common/dtypes_test.py +++ b/keras/backend/common/dtypes_test.py @@ -2,6 +2,7 @@ from keras import backend from keras import ops +from keras.backend.common import dtypes from keras.backend.common.variables import ALLOWED_DTYPES from keras.backend.torch.core import to_torch_dtype from keras.testing import test_case @@ -62,3 +63,63 @@ def test_result_type_invalid_dtypes(self): ValueError, "Invalid `dtypes`. At least one dtype is required." ): backend.result_type() + + def test_respect_weak_type_for_bool(self): + self.assertEqual(dtypes._respect_weak_type("bool", True), "bool") + + def test_respect_weak_type_for_int(self): + self.assertEqual(dtypes._respect_weak_type("int32", True), "int") + + def test_respect_weak_type_for_float(self): + self.assertEqual(dtypes._respect_weak_type("float32", True), "float") + + def test_resolve_weak_type_for_bfloat16(self): + self.assertEqual(dtypes._resolve_weak_type("bfloat16"), "float32") + + def test_resolve_weak_type_for_bfloat16_with_precision(self): + self.assertEqual( + dtypes._resolve_weak_type("bfloat16", precision="64"), "float64" + ) + + def test_invalid_dtype_for_keras_promotion(self): + with self.assertRaisesRegex( + ValueError, "is not a valid dtype for Keras type promotion." + ): + dtypes._least_upper_bound("invalid_dtype") + + def test_resolve_weak_type_for_invalid_dtype(self): + with self.assertRaisesRegex( + ValueError, "Invalid value for argument `dtype`. Expected one of" + ): + dtypes._resolve_weak_type("invalid_dtype") + + def test_resolve_weak_type_for_invalid_precision(self): + with self.assertRaisesRegex( + ValueError, + "Invalid value for argument `precision`. Expected one of", + ): + dtypes._resolve_weak_type("int32", precision="invalid_precision") + + def test_cycle_detection_in_make_lattice_upper_bounds(self): + original_lattice_function = dtypes._type_promotion_lattice + + def mock_lattice(): + lattice = original_lattice_function() + lattice["int32"].append("float32") + lattice["float32"].append("int32") + return lattice + + dtypes._type_promotion_lattice = mock_lattice + + with self.assertRaisesRegex( + ValueError, "cycle detected in type promotion lattice for node" + ): + dtypes._make_lattice_upper_bounds() + + dtypes._type_promotion_lattice = original_lattice_function + + def test_respect_weak_type_for_invalid_dtype(self): + with self.assertRaisesRegex( + ValueError, "Invalid value for argument `dtype`. Expected one of" + ): + dtypes._respect_weak_type("invalid_dtype", True)