Skip to content

Commit

Permalink
Fix int16/float16 support in ops.arange in TF.
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Jan 25, 2024
1 parent 06cde60 commit d55a293
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
8 changes: 7 additions & 1 deletion keras/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,13 @@ def arange(start, stop=None, step=1, dtype=None):
dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
dtype = dtypes.result_type(*dtypes_to_resolve)
dtype = standardize_dtype(dtype)
return tf.range(start, stop, delta=step, dtype=dtype)
try:
out = tf.range(start, stop, delta=step, dtype=dtype)
except tf.errors.NotFoundError:
# Some dtypes may not work in eager mode on CPU or GPU.
out = tf.range(start, stop, delta=step, dtype="float32")
out = tf.cast(out, dtype)
return out


@sparse.densifying_unary(0.5 * np.pi)
Expand Down
2 changes: 2 additions & 0 deletions keras/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4929,6 +4929,8 @@ def test_argsort(self, dtype):
(0.0, 10, 1, None),
(10, None, 1, "float32"),
(10, None, 1, "int32"),
(10, None, 1, "int16"),
(10, None, 1, "float16"),
)
def test_arange(self, start, stop, step, dtype):
import jax.numpy as jnp
Expand Down

0 comments on commit d55a293

Please sign in to comment.