From d55a293702fd5f1aba1d4a546aabe48af3abf529 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Thu, 25 Jan 2024 13:17:28 -0800 Subject: [PATCH] Fix int16/float16 support in ops.arange in TF. --- keras/backend/tensorflow/numpy.py | 8 +++++++- keras/ops/numpy_test.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/keras/backend/tensorflow/numpy.py b/keras/backend/tensorflow/numpy.py index 6bf567da1e7..23b876fb8e5 100644 --- a/keras/backend/tensorflow/numpy.py +++ b/keras/backend/tensorflow/numpy.py @@ -409,7 +409,13 @@ def arange(start, stop=None, step=1, dtype=None): dtypes_to_resolve.append(getattr(stop, "dtype", type(stop))) dtype = dtypes.result_type(*dtypes_to_resolve) dtype = standardize_dtype(dtype) - return tf.range(start, stop, delta=step, dtype=dtype) + try: + out = tf.range(start, stop, delta=step, dtype=dtype) + except tf.errors.NotFoundError: + # Some dtypes may not work in eager mode on CPU or GPU. + out = tf.range(start, stop, delta=step, dtype="float32") + out = tf.cast(out, dtype) + return out @sparse.densifying_unary(0.5 * np.pi) diff --git a/keras/ops/numpy_test.py b/keras/ops/numpy_test.py index 94acbfdce5b..6c535b3b143 100644 --- a/keras/ops/numpy_test.py +++ b/keras/ops/numpy_test.py @@ -4929,6 +4929,8 @@ def test_argsort(self, dtype): (0.0, 10, 1, None), (10, None, 1, "float32"), (10, None, 1, "int32"), + (10, None, 1, "int16"), + (10, None, 1, "float16"), ) def test_arange(self, start, stop, step, dtype): import jax.numpy as jnp