Allow area resize method for TF backend #20263

Closed
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -18,6 +18,7 @@
from keras.src.ops.core import fori_loop
from keras.src.ops.core import is_tensor
from keras.src.ops.core import map
from keras.src.ops.core import saturate_cast
from keras.src.ops.core import scan
from keras.src.ops.core import scatter
from keras.src.ops.core import scatter_update
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -18,6 +18,7 @@
from keras.src.ops.core import fori_loop
from keras.src.ops.core import is_tensor
from keras.src.ops.core import map
from keras.src.ops.core import saturate_cast
from keras.src.ops.core import scan
from keras.src.ops.core import scatter
from keras.src.ops.core import scatter_update
1 change: 1 addition & 0 deletions keras/src/backend/tensorflow/image.py
@@ -13,6 +13,7 @@
"lanczos3",
"lanczos5",
"bicubic",
"area",
)


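Note: the only change in this file is adding "area" to the tuple of interpolations the TensorFlow backend accepts. As a rough usage sketch (not part of the diff, assuming the standard `keras.ops.image.resize` signature), the newly allowed method could be exercised like this:

import numpy as np
from keras import ops

# Hypothetical example: downsample an 8x8 image using the newly allowed
# "area" interpolation under the TensorFlow backend.
image = np.random.uniform(0, 255, size=(1, 8, 8, 3)).astype("float32")
resized = ops.image.resize(image, size=(4, 4), interpolation="area")
print(resized.shape)  # expected: (1, 4, 4, 3)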
92 changes: 92 additions & 0 deletions keras/src/ops/core.py
@@ -802,6 +802,98 @@ def cast(x, dtype):
    return backend.core.cast(x, dtype)


class SaturateCast(Operation):
    def __init__(self, dtype):
        super().__init__()
        self.dtype = backend.standardize_dtype(dtype)

    def call(self, x):
        return _saturate_cast(x, self.dtype)

    def compute_output_spec(self, x):
        return backend.KerasTensor(shape=x.shape, dtype=self.dtype)


@keras_export("keras.ops.saturate_cast")
def saturate_cast(x, dtype):
"""Performs a safe saturating cast to the desired dtype.

Saturating cast prevents data type overflow when casting to `dtype` with
smaller values range. E.g.
`ops.cast(ops.cast([-1, 256], "float32"), "uint8")` returns `[255, 0]`,
but `ops.saturate_cast(ops.cast([-1, 256], "float32"), "uint8")` returns
`[0, 255]`.

Args:
x: A tensor or variable.
dtype: The target type.

Returns:
A safely casted tensor of the specified `dtype`.

Example:

Image resizing with bicubic interpolation may produce values outside
original range.
>>> image2x2 = np.array([0, 1, 254, 255], dtype="uint8").reshape(1, 2, 2, 1)
>>> image4x4 = tf.image.resize(image2x2, (4, 4), method="bicubic")
>>> print(image4x4.numpy().squeeze())
>>> # [[-22.500004 -22.204624 -21.618908 -21.32353 ]
>>> # [ 52.526054 52.82143 53.407146 53.70253 ]
>>> # [201.29752 201.59288 202.17859 202.47395 ]
>>> # [276.32355 276.61893 277.20465 277.50006 ]]

Casting this resized image back to `uint8` will cause overflow.
>>> image4x4_casted = ops.cast(image4x4, "uint8")
>>> print(image4x4_casted.numpy().squeeze())
>>> # [[234 234 235 235]
>>> # [ 52 52 53 53]
>>> # [201 201 202 202]
>>> # [ 20 20 21 21]]

Saturate casting to `uint8` will clip values to `uint8` range before
casting and will not cause overflow.
>>> image4x4_saturate_casted = ops.saturate_cast(image4x4, "uint8")
>>> print(image4x4_saturate_casted.numpy().squeeze())
>>> # [[ 0 0 0 0]
>>> # [ 52 52 53 53]
>>> # [201 201 202 202]
>>> # [255 255 255 255]]

"""
dtype = backend.standardize_dtype(dtype)

if any_symbolic_tensors((x,)):
return SaturateCast(dtype=dtype)(x)
return _saturate_cast(x, dtype)


def _saturate_cast(x, dtype):
    dtype = backend.standardize_dtype(dtype)
    in_dtype = backend.standardize_dtype(x.dtype)
    in_info = np.iinfo(in_dtype) if "int" in in_dtype else np.finfo(in_dtype)
    out_info = np.iinfo(dtype) if "int" in dtype else np.finfo(dtype)

    # The output min/max may not actually be representable in the
    # in_dtype (e.g. casting float32 to uint32). This can lead to undefined
    # behavior when trying to cast a value outside the valid range of the
    # target type. We work around this by nudging the min/max to fall within
    # the valid output range. The catch is that we may actually saturate
    # to a value less than the true saturation limit, but this is the best
    # we can do in order to avoid UB without a backend op.
    min_limit = np.maximum(in_info.min, out_info.min).astype(in_dtype)
    if min_limit < out_info.min:
        min_limit = np.nextafter(min_limit, 0, dtype=in_dtype)
    max_limit = np.minimum(in_info.max, out_info.max).astype(in_dtype)
    if max_limit > out_info.max:
        max_limit = np.nextafter(max_limit, 0, dtype=in_dtype)

    # Unconditionally apply `clip` to fix `inf` behavior.
    x = backend.numpy.clip(x, min_limit, max_limit)

    return cast(x, dtype)
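# For illustration only (not part of the diff): with a float32 -> int32
# cast, int32's max (2147483647) is not exactly representable in float32;
# it rounds up to 2147483648.0, which itself overflows int32. np.nextafter
# toward 0 nudges the clip limit down to the nearest smaller representable
# float32 (2147483520.0), so the final cast stays inside the int32 range.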


@keras_export("keras.ops.convert_to_tensor")
def convert_to_tensor(x, dtype=None, sparse=None):
"""Convert a NumPy array to a tensor.
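For context, a minimal sketch (not part of the diff) of the new op in use, mirroring the docstring example above. The exact wrapped values produced by a plain `cast` depend on the backend, so only the saturating path is shown:

import numpy as np
from keras import ops

# Values outside the uint8 range, as produced by e.g. bicubic resizing.
x = np.array([-22.5, 52.5, 201.3, 277.5], dtype="float32")
y = ops.saturate_cast(x, "uint8")
print(ops.convert_to_numpy(y))  # expected: [  0  52 201 255]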
24 changes: 24 additions & 0 deletions keras/src/ops/core_test.py
@@ -769,6 +769,17 @@ def test_cast_float8(self, float8_dtype):
        self.assertEqual(x.shape, y.shape)
        self.assertTrue(hasattr(x, "_keras_history"))

    def test_saturate_cast(self):
        x = ops.ones((2,), dtype="float32")
        y = ops.saturate_cast(x, "float16")
        self.assertIn("float16", str(y.dtype))

        x = ops.KerasTensor((2,), dtype="float32")
        y = ops.saturate_cast(x, "float16")
        self.assertEqual("float16", y.dtype)
        self.assertEqual(x.shape, y.shape)
        self.assertTrue(hasattr(y, "_keras_history"))

    def test_vectorized_map(self):
        def fn(x):
            return x + 1
@@ -1140,6 +1151,19 @@ def test_cast_basic_functionality(self):
        expected_values = x.astype(target_dtype)
        self.assertTrue(np.array_equal(result, expected_values))

    def test_saturate_cast_basic_functionality(self):
        x = np.array([-256, 1.0, 257.0], dtype=np.float32)
        target_dtype = np.uint8
        cast = core.SaturateCast(target_dtype)
        result = cast.call(x)
        result = core.convert_to_numpy(result)
        self.assertEqual(result.dtype, target_dtype)
        # Values should match a manual clip to the uint8 range, then a cast.
        expected_values = np.clip(x, 0, 255).astype(target_dtype)
        self.assertTrue(np.array_equal(result, expected_values))

    def test_cond_check_output_spec_list_tuple(self):
        cond_op = core.Cond()
        mock_spec = Mock(dtype="float32", shape=(2, 2))