From 89b38241e07602217d7309243842cc8f0ba3c07f Mon Sep 17 00:00:00 2001 From: Shkarupa Alex Date: Tue, 3 Sep 2024 15:37:24 +0300 Subject: [PATCH 1/3] SaturateCast op --- keras/api/_tf_keras/keras/ops/__init__.py | 1 + keras/api/ops/__init__.py | 1 + keras/src/ops/core.py | 61 +++++++++++++++++++++++ keras/src/ops/core_test.py | 24 +++++++++ 4 files changed, 87 insertions(+) diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index 1b106ff3313..60ab6bed2b7 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -18,6 +18,7 @@ from keras.src.ops.core import fori_loop from keras.src.ops.core import is_tensor from keras.src.ops.core import map +from keras.src.ops.core import saturate_cast from keras.src.ops.core import scan from keras.src.ops.core import scatter from keras.src.ops.core import scatter_update diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index 1b106ff3313..60ab6bed2b7 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -18,6 +18,7 @@ from keras.src.ops.core import fori_loop from keras.src.ops.core import is_tensor from keras.src.ops.core import map +from keras.src.ops.core import saturate_cast from keras.src.ops.core import scan from keras.src.ops.core import scatter from keras.src.ops.core import scatter_update diff --git a/keras/src/ops/core.py b/keras/src/ops/core.py index 8f1f84884e7..4edcbf970d5 100644 --- a/keras/src/ops/core.py +++ b/keras/src/ops/core.py @@ -802,6 +802,67 @@ def cast(x, dtype): return backend.core.cast(x, dtype) +class SaturateCast(Operation): + def __init__(self, dtype): + super().__init__() + self.dtype = backend.standardize_dtype(dtype) + + def call(self, x): + return _saturate_cast(x, self.dtype) + + def compute_output_spec(self, x): + return backend.KerasTensor(shape=x.shape, dtype=self.dtype) + + +@keras_export("keras.ops.saturate_cast") +def saturate_cast(x, dtype): + """Performs a safe saturating cast to the desired dtype. + + Args: + x: A tensor or variable. + dtype: The target type. + + Returns: + A safely casted tensor of the specified `dtype`. + + Example: + + >>> x = keras.ops.arange(-258, 259) + >>> x = keras.ops.saturate_cast(x, dtype="uint8") + """ + dtype = backend.standardize_dtype(dtype) + + if any_symbolic_tensors((x,)): + return SaturateCast(dtype=dtype)(x) + return _saturate_cast(x, dtype) + + +def _saturate_cast(x, dtype): + dtype = backend.standardize_dtype(dtype) + in_dtype = backend.standardize_dtype(x.dtype) + in_info = np.iinfo(in_dtype) if "int" in in_dtype else np.finfo(in_dtype) + out_info = np.iinfo(dtype) if "int" in dtype else np.finfo(dtype) + + # The output min/max may not actually be representable in the + # in_dtype (e.g. casting float32 to uint32). This can lead to undefined + # behavior when trying to cast a value outside the valid range of the + # target type. We work around this by nudging the min/max to fall within + # the valid output range. The catch is that we may actually saturate + # to a value less than the true saturation limit, but this is the best we + # can do in order to avoid UB without backend op. + min_limit = np.maximum(in_info.min, out_info.min).astype(in_dtype) + if min_limit < out_info.min: + min_limit = np.nextafter(min_limit, 0, dtype=in_dtype) + max_limit = np.minimum(in_info.max, out_info.max).astype(in_dtype) + if max_limit > out_info.max: + max_limit = np.nextafter(max_limit, 0, dtype=in_dtype) + + # Unconditionally apply `clip` to fix `inf` behavior. + x = backend.numpy.clip(x, min_limit, max_limit) + + return cast(x, dtype) + + @keras_export("keras.ops.convert_to_tensor") def convert_to_tensor(x, dtype=None, sparse=None): """Convert a NumPy array to a tensor. diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index 675a2ab357f..45ea8800ca2 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -769,6 +769,17 @@ def test_cast_float8(self, float8_dtype): self.assertEqual(x.shape, y.shape) self.assertTrue(hasattr(x, "_keras_history")) + def test_saturate_cast(self): + x = ops.ones((2,), dtype="float32") + y = ops.saturate_cast(x, "float16") + self.assertIn("float16", str(y.dtype)) + + x = ops.KerasTensor((2,), dtype="float32") + y = ops.saturate_cast(x, "float16") + self.assertEqual("float16", y.dtype) + self.assertEqual(x.shape, y.shape) + self.assertTrue(hasattr(y, "_keras_history")) + def test_vectorized_map(self): def fn(x): return x + 1 @@ -1140,6 +1151,19 @@ def test_cast_basic_functionality(self): expected_values = x.astype(target_dtype) self.assertTrue(np.array_equal(result, expected_values)) + def test_saturate_cast_basic_functionality(self): + x = np.array([-256, 1.0, 257.0], dtype=np.float32) + target_dtype = np.uint8 + cast = core.SaturateCast(target_dtype) + result = cast.call(x) + result = core.convert_to_numpy(result) + self.assertEqual(result.dtype, target_dtype) + # Check that the values are the same + expected_values = np.clip(x, 0, 255).astype(target_dtype) + print(result) + print(expected_values) + self.assertTrue(np.array_equal(result, expected_values)) + def test_cond_check_output_spec_list_tuple(self): cond_op = core.Cond() mock_spec = Mock(dtype="float32", shape=(2, 2)) From 9c529d951982f4ea38a229276c9d6656b341ae97 Mon Sep 17 00:00:00 2001 From: Shkarupa Alex Date: Wed, 4 Sep 2024 22:46:00 +0300 Subject: [PATCH 2/3] What a saturate cast is and when users should want to do it --- keras/src/ops/core.py | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/keras/src/ops/core.py b/keras/src/ops/core.py index 4edcbf970d5..9de471c0a1b 100644 --- a/keras/src/ops/core.py +++ b/keras/src/ops/core.py @@ -818,6 +818,12 @@ def compute_output_spec(self, x): def saturate_cast(x, dtype): """Performs a safe saturating cast to the desired dtype. + Saturating cast prevents data type overflow when casting to `dtype` with + smaller values range. E.g. + `ops.cast(ops.cast([-1, 256], "float32"), "uint8")` returns `[255, 0]`, + but `ops.saturate_cast(ops.cast([-1, 256], "float32"), "uint8")` returns + `[0, 255]`. + Args: x: A tensor or variable. dtype: The target type. @@ -827,8 +833,33 @@ def saturate_cast(x, dtype): Example: - >>> x = keras.ops.arange(-258, 259) - >>> x = keras.ops.saturate_cast(x, dtype="uint8") + Image resizing with bicubic interpolation may produce values outside + original range. + >>> image2x2 = np.array([0, 1, 254, 255], dtype="uint8").reshape(1, 2, 2, 1) + >>> image4x4 = tf.image.resize(image2x2, (4, 4), method="bicubic") + >>> print(image4x4.numpy().squeeze()) + >>> # [[-22.500004 -22.204624 -21.618908 -21.32353 ] + >>> # [ 52.526054 52.82143 53.407146 53.70253 ] + >>> # [201.29752 201.59288 202.17859 202.47395 ] + >>> # [276.32355 276.61893 277.20465 277.50006 ]] + + Casting this resized image back to `uint8` will cause overflow. + >>> image4x4_casted = ops.cast(image4x4, "uint8") + >>> print(image4x4_casted.numpy().squeeze()) + >>> # [[234 234 235 235] + >>> # [ 52 52 53 53] + >>> # [201 201 202 202] + >>> # [ 20 20 21 21]] + + Saturate casting to `uint8` will clip values to `uint8` range before + casting and will not cause overflow. + >>> image4x4_saturate_casted = ops.saturate_cast(image4x4, "uint8") + >>> print(image4x4_saturate_casted.numpy().squeeze()) + >>> # [[ 0 0 0 0] + >>> # [ 52 52 53 53] + >>> # [201 201 202 202] + >>> # [255 255 255 255]] + """ dtype = backend.standardize_dtype(dtype) From 2e815aef9475064d806e5ffdc62bea2907661f72 Mon Sep 17 00:00:00 2001 From: Shkarupa Alex Date: Tue, 17 Sep 2024 12:57:05 +0300 Subject: [PATCH 3/3] Allow AREA interpolation for TF backend --- keras/src/backend/tensorflow/image.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras/src/backend/tensorflow/image.py b/keras/src/backend/tensorflow/image.py index 7b5f297f975..5302ae9fb8e 100644 --- a/keras/src/backend/tensorflow/image.py +++ b/keras/src/backend/tensorflow/image.py @@ -13,6 +13,7 @@ "lanczos3", "lanczos5", "bicubic", + "area", )