Skip to content

Commit

Permalink
Increase test coverage in common/dtypes_test.py (#18618)
Browse files Browse the repository at this point in the history
* Add tests in `common/dtypes_test.py`

* Add tests in `common/dtypes_test.py`

* Add more tests to `/common/dtypes_test.py`

* Add tests in `common/dtypes_test.py`
  • Loading branch information
Faisal-Alsrheed authored Oct 16, 2023
1 parent 41016ff commit cb5ee55
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions keras/backend/common/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from keras import backend
from keras import ops
from keras.backend.common import dtypes
from keras.backend.common.variables import ALLOWED_DTYPES
from keras.backend.torch.core import to_torch_dtype
from keras.testing import test_case
Expand Down Expand Up @@ -62,3 +63,63 @@ def test_result_type_invalid_dtypes(self):
ValueError, "Invalid `dtypes`. At least one dtype is required."
):
backend.result_type()

def test_respect_weak_type_for_bool(self):
self.assertEqual(dtypes._respect_weak_type("bool", True), "bool")

def test_respect_weak_type_for_int(self):
self.assertEqual(dtypes._respect_weak_type("int32", True), "int")

def test_respect_weak_type_for_float(self):
self.assertEqual(dtypes._respect_weak_type("float32", True), "float")

def test_resolve_weak_type_for_bfloat16(self):
self.assertEqual(dtypes._resolve_weak_type("bfloat16"), "float32")

def test_resolve_weak_type_for_bfloat16_with_precision(self):
self.assertEqual(
dtypes._resolve_weak_type("bfloat16", precision="64"), "float64"
)

def test_invalid_dtype_for_keras_promotion(self):
with self.assertRaisesRegex(
ValueError, "is not a valid dtype for Keras type promotion."
):
dtypes._least_upper_bound("invalid_dtype")

def test_resolve_weak_type_for_invalid_dtype(self):
with self.assertRaisesRegex(
ValueError, "Invalid value for argument `dtype`. Expected one of"
):
dtypes._resolve_weak_type("invalid_dtype")

def test_resolve_weak_type_for_invalid_precision(self):
with self.assertRaisesRegex(
ValueError,
"Invalid value for argument `precision`. Expected one of",
):
dtypes._resolve_weak_type("int32", precision="invalid_precision")

def test_cycle_detection_in_make_lattice_upper_bounds(self):
original_lattice_function = dtypes._type_promotion_lattice

def mock_lattice():
lattice = original_lattice_function()
lattice["int32"].append("float32")
lattice["float32"].append("int32")
return lattice

dtypes._type_promotion_lattice = mock_lattice

with self.assertRaisesRegex(
ValueError, "cycle detected in type promotion lattice for node"
):
dtypes._make_lattice_upper_bounds()

dtypes._type_promotion_lattice = original_lattice_function

def test_respect_weak_type_for_invalid_dtype(self):
with self.assertRaisesRegex(
ValueError, "Invalid value for argument `dtype`. Expected one of"
):
dtypes._respect_weak_type("invalid_dtype", True)

0 comments on commit cb5ee55

Please sign in to comment.