From 86bd12ce703f0b0f90a33d34a1d88b2d0101d966 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Fri, 29 Dec 2023 05:51:35 +0800 Subject: [PATCH] Replace `tfnp` with native tf ops wherever possible (#18998) * Add dtype test for `floor_divide` * Use native `tf.*` if possible in `backend.numpy` * Apply `result_type` to np's `floor_divide` * Increase test coverage * Fix failed test --- keras/backend/jax/numpy.py | 3 + keras/backend/numpy/numpy.py | 9 + keras/backend/tensorflow/numpy.py | 292 +++++++++++++++++++----------- keras/backend/torch/numpy.py | 11 +- keras/ops/numpy.py | 6 +- keras/ops/numpy_test.py | 86 ++++++++- 6 files changed, 298 insertions(+), 109 deletions(-) diff --git a/keras/backend/jax/numpy.py b/keras/backend/jax/numpy.py index fcb979622e3..50c81b64d91 100644 --- a/keras/backend/jax/numpy.py +++ b/keras/backend/jax/numpy.py @@ -207,6 +207,9 @@ def argmin(x, axis=None): def argsort(x, axis=-1): + x = convert_to_tensor(x) + if x.ndim == 0: + return jnp.argsort(x, axis=None) return jnp.argsort(x, axis=axis) diff --git a/keras/backend/numpy/numpy.py b/keras/backend/numpy/numpy.py index 7a72aaa87b3..bc7a1e312cf 100644 --- a/keras/backend/numpy/numpy.py +++ b/keras/backend/numpy/numpy.py @@ -1037,6 +1037,15 @@ def eye(N, M=None, k=0, dtype=None): def floor_divide(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)) + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) return np.floor_divide(x1, x2) diff --git a/keras/backend/tensorflow/numpy.py b/keras/backend/tensorflow/numpy.py index e1840207a9a..98f672dacff 100644 --- a/keras/backend/tensorflow/numpy.py +++ b/keras/backend/tensorflow/numpy.py @@ -27,7 +27,7 @@ def add(x1, x2): ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) - return tfnp.add(x1, x2) + return tf.add(x1, x2) def bincount(x, weights=None, minlength=0): @@ -83,7 +83,7 @@ def einsum(subscripts, *operands, **kwargs): dtypes_to_resolve.append(x.dtype) result_dtype = dtypes.result_type(*dtypes_to_resolve) compute_dtype = result_dtype - # TODO: tfnp.einsum doesn't support integer dtype with gpu + # TODO: tf.einsum doesn't support integer dtype with gpu if "int" in compute_dtype: compute_dtype = config.floatx() @@ -105,14 +105,14 @@ def subtract(x1, x2): ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) - return tfnp.subtract(x1, x2) + return tf.subtract(x1, x2) def matmul(x1, x2): - x1_shape = x1.shape - x2_shape = x2.shape x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) + x1_shape = x1.shape + x2_shape = x2.shape # TODO: GPU and XLA only support float types compute_dtype = dtypes.result_type(x1.dtype, x2.dtype, float) result_dtype = dtypes.result_type(x1.dtype, x2.dtype) @@ -218,7 +218,15 @@ def sparse_dense_matmul_3d(a, b): output.set_shape(output_shape) return output else: - return tf.cast(tfnp.matmul(x1, x2), result_dtype) + if x1_shape.rank == 2 and x2_shape.rank == 2: + output = tf.matmul(x1, x2) + elif x2_shape.rank == 1: + output = tf.tensordot(x1, x2, axes=1) + elif x1_shape.rank == 1: + output = tf.tensordot(x1, x2, axes=[[0], [-2]]) + else: + output = tf.matmul(x1, x2) + return tf.cast(output, result_dtype) @sparse.elementwise_binary_intersection @@ -233,7 +241,7 @@ def multiply(x1, x2): ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) - 
return tfnp.multiply(x1, x2) + return tf.multiply(x1, x2) def mean(x, axis=None, keepdims=False): @@ -302,13 +310,15 @@ def mean(x, axis=None, keepdims=False): x = convert_to_tensor(x) ori_dtype = standardize_dtype(x.dtype) compute_dtype = dtypes.result_type(x.dtype, "float32") - # `tfnp.mean` does not handle low precision (e.g., float16) overflow + # `tf.reduce_mean` does not handle low precision (e.g., float16) overflow # correctly, so we compute with float32 and cast back to the original type. if "int" in ori_dtype or ori_dtype == "bool": result_dtype = compute_dtype else: result_dtype = ori_dtype - output = tfnp.mean(x, axis=axis, keepdims=keepdims, dtype=compute_dtype) + output = tf.reduce_mean( + tf.cast(x, compute_dtype), axis=axis, keepdims=keepdims + ) return tf.cast(output, result_dtype) @@ -349,28 +359,30 @@ def absolute(x): dtype = standardize_dtype(x.dtype) if "uint" in dtype or dtype == "bool": return x - return tfnp.absolute(x) + return tf.abs(x) @sparse.elementwise_unary def abs(x): - return tfnp.absolute(x) + return absolute(x) def all(x, axis=None, keepdims=False): - return tfnp.all(x, axis=axis, keepdims=keepdims) + x = tf.cast(x, "bool") + return tf.reduce_all(x, axis=axis, keepdims=keepdims) def any(x, axis=None, keepdims=False): - return tfnp.any(x, axis=axis, keepdims=keepdims) + x = tf.cast(x, "bool") + return tf.reduce_any(x, axis=axis, keepdims=keepdims) def amax(x, axis=None, keepdims=False): - return tfnp.amax(x, axis=axis, keepdims=keepdims) + return max(x, axis=axis, keepdims=keepdims) def amin(x, axis=None, keepdims=False): - return tfnp.amin(x, axis=axis, keepdims=keepdims) + return min(x, axis=axis, keepdims=keepdims) def append(x1, x2, axis=None): @@ -379,7 +391,10 @@ def append(x1, x2, axis=None): dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = tf.cast(x1, dtype) x2 = tf.cast(x2, dtype) - return tfnp.append(x1, x2, axis=axis) + if axis is None: + return tf.concat([tf.reshape(x1, [-1]), tf.reshape(x2, [-1])], axis=0) + else: + return tf.concat([x1, x2], axis=axis) def arange(start, stop=None, step=1, dtype=None): @@ -397,7 +412,7 @@ def arange(start, stop=None, step=1, dtype=None): return tf.range(start, stop, delta=step, dtype=dtype) -@sparse.densifying_unary(0.5 * tfnp.pi) +@sparse.densifying_unary(0.5 * np.pi) def arccos(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": @@ -458,7 +473,7 @@ def arctan2(x1, x2): dtype = dtypes.result_type(x1.dtype, x2.dtype, float) x1 = tf.cast(x1, dtype) x2 = tf.cast(x2, dtype) - return tfnp.arctan2(x1, x2) + return tf.math.atan2(x1, x2) @sparse.elementwise_unary @@ -473,18 +488,30 @@ def arctanh(x): def argmax(x, axis=None): - return tf.cast(tfnp.argmax(x, axis=axis), dtype="int32") + if axis is None: + x = tf.reshape(x, [-1]) + return tf.cast(tf.argmax(x, axis=axis), dtype="int32") def argmin(x, axis=None): - return tf.cast(tfnp.argmin(x, axis=axis), dtype="int32") + if axis is None: + x = tf.reshape(x, [-1]) + return tf.cast(tf.argmin(x, axis=axis), dtype="int32") def argsort(x, axis=-1): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "bool": x = tf.cast(x, "uint8") - return tf.cast(tfnp.argsort(x, axis=axis), dtype="int32") + + x_shape = x.shape + if x_shape.rank == 0: + return tf.cast([0], "int32") + + if axis is None: + x = tf.reshape(x, [-1]) + axis = 0 + return tf.argsort(x, axis=axis) def array(x, dtype=None): @@ -515,7 +542,7 @@ def average(x, axis=None, weights=None): def broadcast_to(x, shape): - return tfnp.broadcast_to(x, shape) + return tf.broadcast_to(x, shape) 
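Note on the recurring pattern in the binary ops above (`add`, `subtract`, `multiply`, and the comparison ops further down): the removed `tfnp.*` calls performed NumPy-style dtype promotion internally, whereas native `tf.*` ops generally require both operands to already share a dtype. That is why each rewritten op first computes `dtypes.result_type(...)` and casts both inputs before calling the native op. A minimal sketch of the difference, assuming only a stock TensorFlow install (not part of the patch):

    import tensorflow as tf

    a = tf.constant([1, 2, 3], dtype=tf.int32)
    b = tf.constant([0.5, 1.5, 2.5], dtype=tf.float32)

    # tf.add(a, b)  # fails: native TF ops do not promote mixed int32/float32 inputs
    out = tf.add(tf.cast(a, tf.float32), b)  # promote explicitly first, as this patch does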
@sparse.elementwise_unary @@ -553,17 +580,17 @@ def concatenate(xs, axis=0): if len(dtype_set) > 1: dtype = dtypes.result_type(*dtype_set) xs = tf.nest.map_structure(lambda x: tf.cast(x, dtype), xs) - return tfnp.concatenate(xs, axis=axis) + return tf.concat(xs, axis=axis) @sparse.elementwise_unary def conjugate(x): - return tfnp.conjugate(x) + return tf.math.conj(x) @sparse.elementwise_unary def conj(x): - return tfnp.conjugate(x) + return tf.math.conj(x) @sparse.elementwise_unary @@ -614,17 +641,25 @@ def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): def cumprod(x, axis=None, dtype=None): - dtype = dtypes.result_type(dtype or x.dtype) - if dtype == "bool": - dtype = "int32" - return tfnp.cumprod(x, axis=axis, dtype=dtype) + x = convert_to_tensor(x, dtype=dtype) + # tf.math.cumprod doesn't support bool + if standardize_dtype(x.dtype) == "bool": + x = tf.cast(x, "int32") + if axis is None: + x = tf.reshape(x, [-1]) + axis = 0 + return tf.math.cumprod(x, axis=axis) def cumsum(x, axis=None, dtype=None): - dtype = dtypes.result_type(dtype or x.dtype) - if dtype == "bool": - dtype = "int32" - return tfnp.cumsum(x, axis=axis, dtype=dtype) + x = convert_to_tensor(x, dtype=dtype) + # tf.math.cumsum doesn't support bool + if standardize_dtype(x.dtype) == "bool": + x = tf.cast(x, "int32") + if axis is None: + x = tf.reshape(x, [-1]) + axis = 0 + return tf.math.cumsum(x, axis=axis) def diag(x, k=0): @@ -684,23 +719,30 @@ def dot(x, y): compute_dtype = dtypes.result_type(result_dtype, float) x = tf.cast(x, compute_dtype) y = tf.cast(y, compute_dtype) - return tf.cast(tfnp.dot(x, y), dtype=result_dtype) + + x_shape = x.shape + y_shape = y.shape + if x_shape.rank == 0 or y_shape.rank == 0: + output = x * y + elif y_shape.rank == 1: + output = tf.tensordot(x, y, axes=[[-1], [-1]]) + else: + output = tf.tensordot(x, y, axes=[[-1], [-2]]) + return tf.cast(output, result_dtype) def empty(shape, dtype=None): dtype = dtype or config.floatx() - return tfnp.empty(shape, dtype=dtype) + return tf.zeros(shape, dtype=dtype) def equal(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) - # tfnp handles the casting internally during comparision, but it lacks - # support for bfloat16. Therefore we explicitly cast to the same dtype.
dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = tf.cast(x1, dtype) x2 = tf.cast(x2, dtype) - return tfnp.equal(x1, x2) + return tf.equal(x1, x2) @sparse.densifying_unary(1) @@ -719,7 +761,7 @@ def expand_dims(x, axis): output = tf.sparse.expand_dims(x, axis) output.set_shape(compute_expand_dims_output_shape(x.shape, axis)) return output - return tfnp.expand_dims(x, axis) + return tf.expand_dims(x, axis) @sparse.elementwise_unary @@ -732,7 +774,10 @@ def expm1(x): def flip(x, axis=None): - return tfnp.flip(x, axis=axis) + x = convert_to_tensor(x) + if axis is None: + return tf.reverse(x, tf.range(tf.rank(x))) + return tf.reverse(x, [axis]) @sparse.elementwise_unary @@ -749,33 +794,33 @@ def floor(x): def full(shape, fill_value, dtype=None): dtype = dtype or config.floatx() - return tfnp.full(shape, fill_value, dtype=dtype) + fill_value = convert_to_tensor(fill_value, dtype) + return tf.broadcast_to(fill_value, shape) def full_like(x, fill_value, dtype=None): - return tfnp.full_like(x, fill_value, dtype=dtype) + x = convert_to_tensor(x) + dtype = dtypes.result_type(dtype or x.dtype) + fill_value = convert_to_tensor(fill_value, dtype) + return tf.broadcast_to(fill_value, tf.shape(x)) def greater(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) - # tfnp handles the casting internally during comparision, but it lacks - # support for bfloat16. Therefore we explicitly cast to the same dtype. dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = tf.cast(x1, dtype) x2 = tf.cast(x2, dtype) - return tfnp.greater(x1, x2) + return tf.greater(x1, x2) def greater_equal(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) - # tfnp handles the casting internally during comparision, but it lacks - # support for bfloat16. Therefore we explicitly cast to the same dtype. dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = tf.cast(x1, dtype) x2 = tf.cast(x2, dtype) - return tfnp.greater_equal(x1, x2) + return tf.greater_equal(x1, x2) def hstack(xs): @@ -783,28 +828,35 @@ def hstack(xs): if len(dtype_set) > 1: dtype = dtypes.result_type(*dtype_set) xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs) - return tfnp.hstack(xs) + rank = tf.rank(xs[0]) + return tf.cond( + tf.equal(rank, 1), + lambda: tf.concat(xs, axis=0), + lambda: tf.concat(xs, axis=1), + ) def identity(n, dtype=None): - dtype = dtype or config.floatx() - return tfnp.identity(n, dtype=dtype) + return eye(N=n, M=n, dtype=dtype) @sparse.elementwise_unary def imag(x): - return tfnp.imag(x) + return tf.math.imag(x) def isclose(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) - # tfnp handles the casting internally during comparision, but it lacks - # support for bfloat16. Therefore we explicitly cast to the same dtype. dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = tf.cast(x1, dtype) x2 = tf.cast(x2, dtype) - return tfnp.isclose(x1, x2) + if "float" in dtype: + # atol defaults to 1e-08 + # rtol defaults to 1e-05 + return tf.abs(x1 - x2) <= (1e-08 + 1e-05 * tf.abs(x2)) + else: + return tf.equal(x1, x2) @sparse.densifying_unary(True) @@ -838,23 +890,19 @@ def isnan(x): def less(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) - # tfnp handles the casting internally during comparison, but it lacks - # support for bfloat16. Therefore we explicitly cast to the same dtype. 
dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = tf.cast(x1, dtype) x2 = tf.cast(x2, dtype) - return tfnp.less(x1, x2) + return tf.less(x1, x2) def less_equal(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) - # tfnp handles the casting internally during comparision, but it lacks - # support for bfloat16. Therefore we explicitly cast to the same dtype. dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = tf.cast(x1, dtype) x2 = tf.cast(x2, dtype) - return tfnp.less_equal(x1, x2) + return tf.less_equal(x1, x2) def linspace( @@ -878,7 +926,7 @@ def linspace( ) -@sparse.densifying_unary(-tfnp.inf) +@sparse.densifying_unary(-np.inf) def log(x): x = convert_to_tensor(x) dtype = ( @@ -890,7 +938,7 @@ def log(x): return tf.math.log(x) -@sparse.densifying_unary(-tfnp.inf) +@sparse.densifying_unary(-np.inf) def log10(x): x = convert_to_tensor(x) dtype = ( @@ -914,7 +962,7 @@ def log1p(x): return tf.math.log1p(x) -@sparse.densifying_unary(-tfnp.inf) +@sparse.densifying_unary(-np.inf) def log2(x): x = convert_to_tensor(x) dtype = ( @@ -944,15 +992,20 @@ def logaddexp(x1, x2): def logical_and(x1, x2): - return tfnp.logical_and(x1, x2) + x1 = tf.cast(x1, "bool") + x2 = tf.cast(x2, "bool") + return tf.logical_and(x1, x2) def logical_not(x): - return tfnp.logical_not(x) + x = tf.cast(x, "bool") + return tf.logical_not(x) def logical_or(x1, x2): - return tfnp.logical_or(x1, x2) + x1 = tf.cast(x1, "bool") + x2 = tf.cast(x2, "bool") + return tf.logical_or(x1, x2) def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): @@ -988,7 +1041,7 @@ def maximum(x1, x2): ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) - return tfnp.maximum(x1, x2) + return tf.maximum(x1, x2) def median(x, axis=None, keepdims=False): @@ -996,10 +1049,11 @@ def median(x, axis=None, keepdims=False): def meshgrid(*x, indexing="xy"): - return tfnp.meshgrid(*x, indexing=indexing) + return tf.meshgrid(*x, indexing=indexing) def min(x, axis=None, keepdims=False, initial=None): + x = convert_to_tensor(x) # The TensorFlow numpy API implementation doesn't support `initial` so we # handle it manually here. 
if initial is not None: @@ -1032,7 +1086,7 @@ def minimum(x1, x2): ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) - return tfnp.minimum(x1, x2) + return tf.minimum(x1, x2) def mod(x1, x2): @@ -1043,7 +1097,7 @@ def mod(x1, x2): dtype = "int32" x1 = tf.cast(x1, dtype) x2 = tf.cast(x2, dtype) - return tfnp.mod(x1, x2) + return tf.math.mod(x1, x2) def moveaxis(x, source, destination): @@ -1071,12 +1125,15 @@ def nan_to_num(x): def ndim(x): - return tfnp.ndim(x) + x = convert_to_tensor(x) + return x.ndim def nonzero(x): + x = convert_to_tensor(x) + result = tf.unstack(tf.where(tf.cast(x, "bool")), x.shape.rank, axis=1) return tf.nest.map_structure( - lambda indices: tf.cast(indices, "int32"), tfnp.nonzero(x) + lambda indices: tf.cast(indices, "int32"), result ) @@ -1086,11 +1143,11 @@ def not_equal(x1, x2): dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = tf.cast(x1, dtype) x2 = tf.cast(x2, dtype) - return tfnp.not_equal(x1, x2) + return tf.not_equal(x1, x2) def ones_like(x, dtype=None): - return tfnp.ones_like(x, dtype=dtype) + return tf.ones_like(x, dtype=dtype) def zeros_like(x, dtype=None): @@ -1103,10 +1160,11 @@ def outer(x1, x2): dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = tf.cast(x1, dtype) x2 = tf.cast(x2, dtype) - return tfnp.outer(x1, x2) + return tf.reshape(x1, [-1, 1]) * tf.reshape(x2, [-1]) def pad(x, pad_width, mode="constant", constant_values=None): + x = convert_to_tensor(x) kwargs = {} if constant_values is not None: if mode != "constant": @@ -1116,7 +1174,8 @@ def pad(x, pad_width, mode="constant", constant_values=None): f"Received: mode={mode}" ) kwargs["constant_values"] = constant_values - return tfnp.pad(x, pad_width, mode=mode, **kwargs) + pad_width = convert_to_tensor(pad_width, "int32") + return tf.pad(x, pad_width, mode.upper(), **kwargs) def prod(x, axis=None, keepdims=False, dtype=None): @@ -1129,7 +1188,8 @@ def prod(x, axis=None, keepdims=False, dtype=None): dtype = "int32" elif dtype in ("uint8", "uint16"): dtype = "uint32" - return tfnp.prod(x, axis=axis, keepdims=keepdims, dtype=dtype) + x = tf.cast(x, dtype) + return tf.reduce_prod(x, axis=axis, keepdims=keepdims) def _quantile(x, q, axis=None, method="linear", keepdims=False): @@ -1252,17 +1312,20 @@ def quantile(x, q, axis=None, method="linear", keepdims=False): def ravel(x): - return tfnp.ravel(x) + x = convert_to_tensor(x) + return tf.reshape(x, [-1]) @sparse.elementwise_unary def real(x): - return tfnp.real(x) + x = convert_to_tensor(x) + return tf.math.real(x) -@sparse.densifying_unary(tfnp.inf) +@sparse.densifying_unary(np.inf) def reciprocal(x): - return tfnp.reciprocal(x) + x = convert_to_tensor(x) + return tf.math.reciprocal(x) def repeat(x, repeats, axis=None): @@ -1277,6 +1340,7 @@ def repeat(x, repeats, axis=None): def reshape(x, new_shape): + x = convert_to_tensor(x) if isinstance(x, tf.SparseTensor): from keras.ops.operation_utils import compute_reshape_output_shape @@ -1286,7 +1350,7 @@ def reshape(x, new_shape): output = tf.sparse.reshape(x, new_shape) output.set_shape(output_shape) return output - return tfnp.reshape(x, new_shape) + return tf.reshape(x, new_shape) def roll(x, shift, axis=None): @@ -1327,7 +1391,8 @@ def sinh(x): def size(x): - return tfnp.size(x) + x = convert_to_tensor(x) + return tf.size(x) def sort(x, axis=-1): @@ -1363,7 +1428,7 @@ def stack(x, axis=0): if len(dtype_set) > 1: dtype = dtypes.result_type(*dtype_set) x = tf.nest.map_structure(lambda a: convert_to_tensor(a, dtype), x) - return tfnp.stack(x, axis=axis) + return tf.stack(x, 
axis=axis) def std(x, axis=None, keepdims=False): @@ -1371,7 +1436,7 @@ def std(x, axis=None, keepdims=False): ori_dtype = standardize_dtype(x.dtype) if "int" in ori_dtype or ori_dtype == "bool": x = tf.cast(x, config.floatx()) - return tfnp.std(x, axis=axis, keepdims=keepdims) + return tf.math.reduce_std(x, axis=axis, keepdims=keepdims) def swapaxes(x, axis1, axis2): @@ -1440,11 +1505,11 @@ def tensordot(x1, x2, axes=2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) result_dtype = dtypes.result_type(x1.dtype, x2.dtype) - # TODO: tfnp.tensordot only supports float types + # TODO: tf.tensordot only supports float types compute_dtype = dtypes.result_type(result_dtype, float) x1 = tf.cast(x1, compute_dtype) x2 = tf.cast(x2, compute_dtype) - return tf.cast(tfnp.tensordot(x1, x2, axes=axes), dtype=result_dtype) + return tf.cast(tf.tensordot(x1, x2, axes=axes), dtype=result_dtype) @sparse.elementwise_unary @@ -1529,11 +1594,12 @@ def vdot(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) result_dtype = dtypes.result_type(x1.dtype, x2.dtype) - # TODO: tfnp.vdot only supports float types compute_dtype = dtypes.result_type(result_dtype, float) x1 = tf.cast(x1, compute_dtype) x2 = tf.cast(x2, compute_dtype) - return tf.cast(tfnp.vdot(x1, x2), result_dtype) + x1 = tf.reshape(x1, [-1]) + x2 = tf.reshape(x2, [-1]) + return tf.cast(dot(x1, x2), result_dtype) def vstack(xs): @@ -1541,10 +1607,11 @@ def vstack(xs): if len(dtype_set) > 1: dtype = dtypes.result_type(*dtype_set) xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs) - return tfnp.vstack(xs) + return tf.concat(xs, axis=0) def where(condition, x1, x2): + condition = tf.cast(condition, "bool") if x1 is not None and x2 is not None: if not isinstance(x1, (int, float)): x1 = convert_to_tensor(x1) @@ -1556,7 +1623,13 @@ def where(condition, x1, x2): ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) - return tfnp.where(condition, x1, x2) + return tf.where(condition, x1, x2) + if x1 is None and x2 is None: + return nonzero(condition) + raise ValueError( + "`x1` and `x2` either both should be `None`" + " or both should have non-None value." 
+ ) @sparse.elementwise_division @@ -1572,7 +1645,7 @@ def divide(x1, x2): ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) - return tfnp.divide(x1, x2) + return tf.divide(x1, x2) @sparse.elementwise_division @@ -1589,19 +1662,19 @@ def power(x1, x2): getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)), ) - # TODO: tfnp.power doesn't support uint* types + # TODO: tf.pow doesn't support uint* types if "uint" in dtype: x1 = convert_to_tensor(x1, "int32") x2 = convert_to_tensor(x2, "int32") - return tf.cast(tfnp.power(x1, x2), dtype) + return tf.cast(tf.pow(x1, x2), dtype) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) - return tfnp.power(x1, x2) + return tf.pow(x1, x2) @sparse.elementwise_unary def negative(x): - return tfnp.negative(x) + return tf.negative(x) @sparse.elementwise_unary @@ -1609,7 +1682,7 @@ def square(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "bool": x = tf.cast(x, "int32") - return tfnp.square(x) + return tf.square(x) @sparse.elementwise_unary @@ -1641,7 +1714,7 @@ def squeeze(x, axis=None): gather_indices.append(i) new_indices = tf.gather(x.indices, gather_indices, axis=1) return tf.SparseTensor(new_indices, x.values, tuple(new_shape)) - return tfnp.squeeze(x, axis=axis) + return tf.squeeze(x, axis=axis) def transpose(x, axes=None): @@ -1651,15 +1724,16 @@ def transpose(x, axes=None): output = tf.sparse.transpose(x, perm=axes) output.set_shape(compute_transpose_output_shape(x.shape, axes)) return output - return tfnp.transpose(x, axes=axes) + return tf.transpose(x, perm=axes) def var(x, axis=None, keepdims=False): x = convert_to_tensor(x) compute_dtype = dtypes.result_type(x.dtype, "float32") result_dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, compute_dtype) return tf.cast( - tfnp.var(x, axis=axis, keepdims=keepdims, dtype=compute_dtype), + tf.math.reduce_variance(x, axis=axis, keepdims=keepdims), result_dtype, ) @@ -1682,8 +1756,20 @@ def eye(N, M=None, k=0, dtype=None): def floor_divide(x1, x2): - return tfnp.floor_divide(x1, x2) + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return tf.math.floordiv(x1, x2) def logical_xor(x1, x2): - return tfnp.logical_xor(x1, x2) + x1 = tf.cast(x1, "bool") + x2 = tf.cast(x2, "bool") + return tf.math.logical_xor(x1, x2) diff --git a/keras/backend/torch/numpy.py b/keras/backend/torch/numpy.py index e39c91d6479..3f058f39b7e 100644 --- a/keras/backend/torch/numpy.py +++ b/keras/backend/torch/numpy.py @@ -1450,8 +1450,15 @@ def eye(N, M=None, k=None, dtype=None): def floor_divide(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.floor_divide(x1, x2) + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + return cast(torch.floor_divide(x1, x2), dtype) def logical_xor(x1, x2): diff --git a/keras/ops/numpy.py b/keras/ops/numpy.py index 63b818543ba..3461d0dc542 100644 --- a/keras/ops/numpy.py +++ b/keras/ops/numpy.py @@ -5967,7 +5967,11 @@ def compute_output_spec(self, x1, x2): x1_shape = getattr(x1, "shape", []) x2_shape = getattr(x2, "shape", []) output_shape = 
broadcast_shapes(x1_shape, x2_shape) - return KerasTensor(output_shape, dtype=x1.dtype) + output_dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + return KerasTensor(output_shape, dtype=output_dtype) @keras_export(["keras.ops.floor_divide", "keras.ops.numpy.floor_divide"]) diff --git a/keras/ops/numpy_test.py b/keras/ops/numpy_test.py index 0880839ec0f..0e9d80e0f9d 100644 --- a/keras/ops/numpy_test.py +++ b/keras/ops/numpy_test.py @@ -269,7 +269,7 @@ def test_where(self): self.assertEqual(knp.where(condition, x, y).shape, (2, None, 3)) self.assertEqual(knp.where(condition).shape, (2, None, 1)) - def test_floordiv(self): + def test_floor_divide(self): x = KerasTensor((None, 3)) y = KerasTensor((2, None)) self.assertEqual(knp.floor_divide(x, y).shape, (2, 3)) @@ -796,7 +796,7 @@ def test_where(self): self.assertEqual(knp.where(condition, x, y).shape, (2, 3)) self.assertAllEqual(knp.where(condition).shape, (2, 3)) - def test_floordiv(self): + def test_floor_divide(self): x = KerasTensor((2, 3)) y = KerasTensor((2, 3)) self.assertEqual(knp.floor_divide(x, y).shape, (2, 3)) @@ -1984,11 +1984,14 @@ def test_matmul(self): x = np.ones([2, 3, 4, 5]) y = np.ones([2, 3, 5, 6]) z = np.ones([5, 6]) + p = np.ones([4]) self.assertAllClose(knp.matmul(x, y), np.matmul(x, y)) self.assertAllClose(knp.matmul(x, z), np.matmul(x, z)) + self.assertAllClose(knp.matmul(p, x), np.matmul(p, x)) self.assertAllClose(knp.Matmul()(x, y), np.matmul(x, y)) self.assertAllClose(knp.Matmul()(x, z), np.matmul(x, z)) + self.assertAllClose(knp.Matmul()(p, x), np.matmul(p, x)) @parameterized.named_parameters( named_product( @@ -2570,6 +2573,11 @@ def test_where(self): self.assertAllClose(knp.where(x > 1), np.where(x > 1)) self.assertAllClose(knp.Where()(x > 1), np.where(x > 1)) + with self.assertRaisesRegexp( + ValueError, "`x1` and `x2` either both should be `None`" + ): + knp.where(x > 1, x, None) + def test_digitize(self): x = np.array([0.0, 1.0, 3.0, 1.6]) bins = np.array([0.0, 3.0, 4.5, 7.0]) @@ -2927,6 +2935,10 @@ def test_argsort(self): self.assertAllClose(knp.Argsort(axis=1)(x), np.argsort(x, axis=1)) self.assertAllClose(knp.Argsort(axis=None)(x), np.argsort(x, axis=None)) + x = np.array(1) # rank == 0 + self.assertAllClose(knp.argsort(x), np.argsort(x)) + self.assertAllClose(knp.Argsort()(x), np.argsort(x)) + def test_array(self): x = np.array([[1, 2, 3], [3, 2, 1]]) self.assertAllClose(knp.array(x), np.array(x)) @@ -3923,7 +3935,7 @@ def test_vstack(self): self.assertAllClose(knp.vstack([x, y]), np.vstack([x, y])) self.assertAllClose(knp.Vstack()([x, y]), np.vstack([x, y])) - def test_floordiv(self): + def test_floor_divide(self): x = np.array([[1, 2, 3], [3, 2, 1]]) y = np.array([[4, 5, 6], [3, 2, 1]]) z = np.array([[[1, 2, 3], [3, 2, 1]]]) @@ -5613,6 +5625,74 @@ def test_floor(self, dtype): expected_dtype, ) + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_floor_divide(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.floor_divide(x1_jax, x2_jax).dtype + ) + + self.assertEqual( + standardize_dtype(knp.floor_divide(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.FloorDivide().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + 
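The dtype tests above use `jax.numpy` as the promotion reference, which is why the backend implementations pass `getattr(x, "dtype", type(x))` into `dtypes.result_type`: Python scalars are treated as weak types so that, for example, `float16 // 1` stays `float16` rather than being upcast. A small illustrative sketch of the kind of reference values involved (the actual expectations are computed from `jnp` at test time, so these literals are only indicative):

    import jax.numpy as jnp

    jnp.floor_divide(jnp.ones((), "int8"), jnp.ones((), "uint8")).dtype  # int16 under NumPy-style promotion
    jnp.floor_divide(jnp.ones((), "float16"), 1).dtype  # float16; the Python int does not force an upcast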
@parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_floor_divide_python_types(self, dtype): + import jax.experimental + import jax.numpy as jnp + + # We have to disable x64 for jax since jnp.floor_divide doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast + # the expected dtype from 64 bit to 32 bit when using jax backend. + with jax.experimental.disable_x64(): + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + + # python int + expected_dtype = standardize_dtype(jnp.floor_divide(x_jax, 1).dtype) + if dtype == "float64": + expected_dtype = "float64" + elif dtype == "int64": + expected_dtype = "int64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.floor_divide(x, 1).dtype), expected_dtype + ) + self.assertEqual( + knp.FloorDivide().symbolic_call(x, 1).dtype, expected_dtype + ) + + # python float + expected_dtype = standardize_dtype( + jnp.floor_divide(x_jax, 1.0).dtype + ) + if dtype == "float64": + expected_dtype = "float64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.floor_divide(x, 1.0).dtype), + expected_dtype, + ) + self.assertEqual( + knp.FloorDivide().symbolic_call(x, 1.0).dtype, expected_dtype + ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) def test_full(self, dtype): import jax.numpy as jnp