From 9bfaf01dd39b84cfda0763f22c07d121541abae5 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 7 May 2024 15:41:35 +0800 Subject: [PATCH 1/5] Add `scan` --- keras/api/_tf_keras/keras/ops/__init__.py | 1 + keras/api/ops/__init__.py | 1 + keras/src/backend/jax/core.py | 6 ++ keras/src/backend/numpy/core.py | 48 +++++++++ keras/src/backend/tensorflow/core.py | 121 ++++++++++++++++++++++ keras/src/backend/torch/core.py | 48 +++++++++ keras/src/ops/core.py | 106 +++++++++++++++++++ keras/src/ops/core_test.py | 86 +++++++++++++++ 8 files changed, 417 insertions(+) diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index be8f00acb55..da6d46b0c72 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -15,6 +15,7 @@ from keras.src.ops.core import custom_gradient from keras.src.ops.core import fori_loop from keras.src.ops.core import is_tensor +from keras.src.ops.core import scan from keras.src.ops.core import scatter from keras.src.ops.core import scatter_update from keras.src.ops.core import shape diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index be8f00acb55..da6d46b0c72 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -15,6 +15,7 @@ from keras.src.ops.core import custom_gradient from keras.src.ops.core import fori_loop from keras.src.ops.core import is_tensor +from keras.src.ops.core import scan from keras.src.ops.core import scatter from keras.src.ops.core import scatter_update from keras.src.ops.core import shape diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index d0b5f9c8e44..57ef60e2c57 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -276,6 +276,12 @@ def vectorized_map(function, elements): return jax.vmap(function)(elements) +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + return 
jax.lax.scan( + f, init=init, xs=xs, length=length, reverse=reverse, unroll=unroll + ) + + def scatter(indices, values, shape): zeros = jnp.zeros(shape, values.dtype) key = tuple(jnp.moveaxis(indices, -1, 0)) diff --git a/keras/src/backend/numpy/core.py b/keras/src/backend/numpy/core.py index c00b2598dfd..5c2af12a134 100644 --- a/keras/src/backend/numpy/core.py +++ b/keras/src/backend/numpy/core.py @@ -140,6 +140,54 @@ def convert_numpy_to_keras_tensor(x): return output_spec +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + # Ref: https://github.com/google/jax/blob/jaxlib-v0.4.26/jax/_src/lax/control_flow/loops.py#L105 + if not callable(f): + raise TypeError(f"`f` should be a callable. Received: f={f}") + if not isinstance(unroll, (bool, int)): + raise TypeError( + "`unroll` must be an positive integer or boolean. " + f"Received: unroll={unroll}" + ) + if xs is None and length is None: + raise ValueError("Got no `xs` to scan over and `length` not provided.") + + input_is_sequence = tree.is_nested(xs) + output_is_sequence = tree.is_nested(init) + + def pack_input(x): + return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0] + + def pack_output(x): + return tree.pack_sequence_as(init, x) if output_is_sequence else x[0] + + if xs is None: + xs_flat = [] + n = int(length) + else: + xs_flat = tree.flatten(xs) + xs_flat = [convert_to_tensor(elem) for elem in xs_flat] + n = int(length) if length is not None else shape(xs_flat[0])[0] + + init_flat = tree.flatten(init) + init_flat = [convert_to_tensor(init) for init in init_flat] + init = pack_output(init_flat) + dummy_y = [np.zeros_like(init) for init in init_flat] + + carry = init + ys = [] + maybe_reversed = reversed if reverse else lambda x: x + for i in maybe_reversed(range(n)): + xs_slice = [x[i] for x in xs_flat] + packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None + carry, y = f(carry, packed_xs) + ys.append(y if y is not None else dummy_y) + stacked_y = tree.map_structure( 
+ lambda *ys: np.stack(ys), *maybe_reversed(ys) + ) + return carry, stacked_y + + def scatter(indices, values, shape): indices = convert_to_tensor(indices) values = convert_to_tensor(values) diff --git a/keras/src/backend/tensorflow/core.py b/keras/src/backend/tensorflow/core.py index 89629215f0a..7ca92525e8d 100644 --- a/keras/src/backend/tensorflow/core.py +++ b/keras/src/backend/tensorflow/core.py @@ -210,6 +210,127 @@ def vectorized_map(function, elements): return tf.vectorized_map(function, elements) +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + # We have reimplemented `scan` to match the behavior of `jax.lax.scan` + # Ref: https://github.com/tensorflow/tensorflow/blob/v2.16.1/tensorflow/python/ops/functional_ops.py#L437 + # Ref: https://github.com/google/jax/blob/jaxlib-v0.4.26/jax/_src/lax/control_flow/loops.py#L105 + if not callable(f): + raise TypeError(f"`f` should be a callable. Received: f={f}") + if not isinstance(unroll, (bool, int)): + raise TypeError( + "`unroll` must be an positive integer or boolean. 
" + f"Received: unroll={unroll}" + ) + if xs is None and length is None: + raise ValueError("Got no `xs` to scan over and `length` not provided.") + + input_is_sequence = tree.is_nested(xs) + output_is_sequence = tree.is_nested(init) + + def pack_input(x): + return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0] + + def pack_output(x): + return tree.pack_sequence_as(init, x) if output_is_sequence else x[0] + + if xs is None: + xs_flat = [] + n = int(length) + else: + # xs_flat = flatten_input(xs) + xs_flat = tree.flatten(xs) + xs_flat = [tf.convert_to_tensor(elem) for elem in xs_flat] + n = int(length) if length is not None else tf.shape(xs_flat[0])[0] + + # TensorArrays are always flat + xs_array = [ + tf.TensorArray( + dtype=x.dtype, + size=n, + dynamic_size=False, + element_shape=x.shape[1:], + infer_shape=True, + ) + for x in xs_flat + ] + xs_array = [x_a.unstack(x) for x_a, x in zip(xs_array, xs_flat)] + + init_flat = tree.flatten(init) + carry_flat = [tf.convert_to_tensor(init) for init in init_flat] + + # Store the intermediate values + # Note: there is a constraint that the output of `f` must have the same + # shape and dtype as carry (`init`). 
+ ys_array = [ + tf.TensorArray( + dtype=carry.dtype, + size=n, + dynamic_size=False, + element_shape=carry.shape, + infer_shape=True, + ) + for carry in carry_flat + ] + carry_array = [ + tf.TensorArray( + dtype=carry.dtype, + size=1, + dynamic_size=False, + clear_after_read=False, + element_shape=carry.shape, + infer_shape=True, + ) + for carry in carry_flat + ] + carry_array = [ + carry.write(0, c) for (carry, c) in zip(carry_array, carry_flat) + ] + + def loop_body(i, carry_array, ys_array): + packed_xs = ( + pack_input([xs.read(i) for xs in xs_array]) + if len(xs_array) > 0 + else None + ) + packed_carry = pack_output([carry.read(0) for carry in carry_array]) + + carry, ys = f(packed_carry, packed_xs) + + if ys is not None: + flat_ys = tree.flatten(ys) + ys_array = [ys.write(i, v) for (ys, v) in zip(ys_array, flat_ys)] + if carry is not None: + flat_carry = tree.flatten(carry) + carry_array = [ + carry.write(0, v) for (carry, v) in zip(carry_array, flat_carry) + ] + next_i = i + 1 if not reverse else i - 1 + return (next_i, carry_array, ys_array) + + if isinstance(unroll, bool): + unroll = max(n, 1) if unroll else 1 + + _, carry_array, ys_array = tf.while_loop( + lambda i, _1, _2: i >= 0 if reverse else i < n, + loop_body, + (n - 1 if reverse else 0, carry_array, ys_array), + parallel_iterations=unroll, + ) + + ys_flat = [ys.stack() for ys in ys_array] + carry_flat = [carry.read(0) for carry in carry_array] + if xs is not None: + n_static = xs_flat[0].get_shape().with_rank_at_least(1)[0] + if not isinstance(n_static, int): + for x in xs_flat[1:]: + n_static.assert_is_compatible_with( + x.get_shape().with_rank_at_least(1)[0] + ) + for r in ys_flat: + r.set_shape(tf.TensorShape(n_static).concatenate(r.get_shape()[1:])) + return pack_output(carry_flat), pack_output(ys_flat) + + def scatter(indices, values, shape): return tf.scatter_nd(indices, values, shape) diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 
68453255b1f..7f443779b8b 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -340,6 +340,54 @@ def vectorized_map(function, elements): return torch.vmap(function)(elements) +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + # Ref: https://github.com/google/jax/blob/jaxlib-v0.4.26/jax/_src/lax/control_flow/loops.py#L105 + if not callable(f): + raise TypeError(f"`f` should be a callable. Received: f={f}") + if not isinstance(unroll, (bool, int)): + raise TypeError( + "`unroll` must be an positive integer or boolean. " + f"Received: unroll={unroll}" + ) + if xs is None and length is None: + raise ValueError("Got no `xs` to scan over and `length` not provided.") + + input_is_sequence = tree.is_nested(xs) + output_is_sequence = tree.is_nested(init) + + def pack_input(x): + return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0] + + def pack_output(x): + return tree.pack_sequence_as(init, x) if output_is_sequence else x[0] + + if xs is None: + xs_flat = [] + n = int(length) + else: + xs_flat = tree.flatten(xs) + xs_flat = [convert_to_tensor(elem) for elem in xs_flat] + n = int(length) if length is not None else shape(xs_flat[0])[0] + + init_flat = tree.flatten(init) + init_flat = [convert_to_tensor(init) for init in init_flat] + init = pack_output(init_flat) + dummy_y = [torch.zeros_like(init) for init in init_flat] + + carry = init + ys = [] + maybe_reversed = reversed if reverse else lambda x: x + for i in maybe_reversed(range(n)): + xs_slice = [x[i] for x in xs_flat] + packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None + carry, y = f(carry, packed_xs) + ys.append(y if y is not None else dummy_y) + stacked_y = tree.map_structure( + lambda *ys: torch.stack(ys), *maybe_reversed(ys) + ) + return carry, stacked_y + + def scatter(indices, values, shape): indices = convert_to_tensor(indices) values = convert_to_tensor(values) diff --git a/keras/src/ops/core.py b/keras/src/ops/core.py index 
5f581fa8678..557edee62d2 100644 --- a/keras/src/ops/core.py +++ b/keras/src/ops/core.py @@ -1,4 +1,5 @@ """ +scan scatter scatter_update slice @@ -25,6 +26,111 @@ from keras.src.utils import traceback_utils +class Scan(Operation): + def __init__(self, reverse=False, unroll=1): + super().__init__() + self.reverse = reverse + self.unroll = unroll + + def call(self, f, init, xs, length): + return backend.core.scan( + f, init, xs, length, reverse=self.reverse, unroll=self.unroll + ) + + def compute_output_spec(self, f, init, xs, length): + if xs is None: + n = int(length) + else: + n = ( + int(length) + if length is not None + else tree.flatten(xs)[0].shape[0] + ) + + carry_spec, y_spec = backend.compute_output_spec(f, init, xs[0]) + y_spec.shape = (n,) + y_spec.shape + return carry_spec, y_spec + + +@keras_export("keras.ops.scan") +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + """Scan a function over leading array axes while carrying along state. + + When the type of `xs` is an array type or `None`, and the type of `ys` is an + array type, the semantics of `scan()` are given roughly by this Python + implementation: + + ```python + def scan(f, init, xs, length=None): + if xs is None: + xs = [None] * length + carry = init + ys = [] + for x in xs: + carry, y = f(carry, x) + ys.append(y) + return carry, np.stack(ys) + ``` + + The loop-carried value `carry` (`init`) must hold a fixed shape and dtype + across all iterations. + + In TensorFlow, `y` must match `carry` in shape and dtype. This is not + required in other backends. + + Args: + f: Callable defines the logic for each loop iteration. This accepts two + arguments where the first is a value of the loop carry and the + second is a slice of `xs` along its leading axis. + This callable returns a pair where the first represents a new value + for the loop carry and the second represents a slice of the output. + init: The initial loop carry value. 
This can be a scalar, tensor, or any + nested structure. It must match the structure of the first element + returned by `f`. + xs: Optional value to scan along its leading axis. This can be a tensor + or any nested structure. If `xs` is not provided, you must specify + `length` to define the number of loop iterations. + Defaults to `None`. + length: Optional integer specifying the number of loop iterations. + If `length` is not provided, it defaults to the size of the leading + axis of the arrays in `xs`. Defaults to `None`. + reverse: Optional boolean specifying whether to run the scan iteration + forward or in reverse, equivalent to reversing the leading axes of + the arrays in both `xs` and in `ys`. + unroll: Optional positive integer or boolean specifying how many scan + iterations to unroll within a single iteration of a loop. If an + integer is provided, it determines how many unrolled loop iterations + to run within a single rolled iteration of the loop. If a boolean is + provided, it will determine if the loop is completely unrolled + (`unroll=True`) or left completely rolled (`unroll=False`). + Note that unrolling is only supported by JAX and TensorFlow + backends. + + Returns: + A pair where the first element represents the final loop carry value and + the second element represents the stacked outputs of `f` when scanned + over the leading axis of the inputs. 
+ + Examples: + + >>> sum_fn = lambda c, x: (c + x, c + x) + >>> init = keras.ops.array(0) + >>> xs = keras.ops.array([1, 2, 3, 4, 5]) + >>> carry, result = ops.scan(sum_fn, init, xs) + >>> carry + 15 + >>> result + [1, 3, 6, 10, 15] + """ + if any_symbolic_tensors((init, xs)): + return Scan(reverse=reverse, unroll=unroll).symbolic_call( + f, init, xs, length + ) + return backend.core.scan( + f, init, xs, length, reverse=reverse, unroll=unroll + ) + + class Scatter(Operation): def call(self, indices, values, shape): return backend.core.scatter(indices, values, shape) diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index 3993b81637b..b7450097cf1 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -19,6 +19,17 @@ class CoreOpsStaticShapeTest(testing.TestCase): + def test_scan(self): + def f(carry, xs): + xs = xs + carry + return carry, carry + + init = KerasTensor(()) + xs = KerasTensor((6,)) + carry, result = core.scan(f, init, xs) + self.assertEqual(carry.shape, ()) + self.assertEqual(result.shape, (6,)) + def test_scatter(self): indices = KerasTensor((5, 2)) values = KerasTensor((5,)) @@ -85,6 +96,69 @@ def test_unstack(self): class CoreOpsCorrectnessTest(testing.TestCase, parameterized.TestCase): + def test_scan(self): + # Test cumsum + def cumsum(carry, xs): + carry = carry + xs + return carry, carry + + init = np.array(0, dtype="float32") + xs = np.array([1, 2, 3, 4, 10, 20], dtype="float32") + carry, result = core.scan(cumsum, init, xs) + self.assertAllClose(carry, 40.0) + self.assertAllClose(result, ops.cumsum(xs)) + + # Test reverse=True + carry, result = core.scan(cumsum, init, xs, reverse=True) + self.assertAllClose(carry, 40.0) + self.assertAllClose(result, [40, 39, 37, 34, 30, 20]) + + # Test unroll + for unroll in (True, False, 2): + carry, result = core.scan(cumsum, init, xs, unroll=unroll) + self.assertAllClose(carry, 40.0) + self.assertAllClose(result, ops.cumsum(xs)) + + # Test xs is None + def 
fibonaccis(carry, _): + return (carry[1], carry[0] + carry[1]), None + + init = (np.array(0, dtype="float32"), np.array(1, dtype="float32")) + carry, _ = core.scan(fibonaccis, init, length=6) + self.assertAllClose(carry, [8, 13]) + + # Test nested init + if backend.backend() != "tensorflow": + # tensorflow doesn't support arbitrary shape/dtype of the output of + # `f`. It must be the same as `init`. + def multiply_two(carry, _): + value1 = carry["value1"] + value2 = carry["value2"] + return ( + {"value1": value1 * 2, "value2": value2 * 2}, + value1 * 2 + value2 * 2, + ) + + init = {"value1": 2.0, "value2": 3.0} + carry, result = core.scan(multiply_two, init, length=3) + self.assertAllClose(carry["value1"], 16) + self.assertAllClose(carry["value2"], 24) + self.assertAllClose(result, [10, 20, 40]) + + # Test nested xs + def reduce_add(carry, xs): + value1 = xs["value1"] + value2 = xs["value2"] + return carry, value1 + value2 + + init = np.array(0, dtype="float32") + xs = { + "value1": np.array([1, 2, 3], dtype="float32"), + "value2": np.array([10, 20, 30], dtype="float32"), + } + _, result = core.scan(reduce_add, init, xs) + self.assertAllClose(result, [11, 22, 33]) + def test_scatter(self): # Test 1D indices = np.array([[1], [3], [4], [7]]) @@ -642,6 +716,18 @@ def test_convert_to_numpy(self): class CoreOpsCallsTests(testing.TestCase): + def test_scan_basic_call(self): + def cumsum(carry, xs): + carry = carry + xs + return carry, carry + + init = np.array(0, dtype="float32") + xs = np.array([1, 2, 3, 4, 10, 20], dtype="float32") + scan_op = core.Scan() + carry, result = scan_op.call(cumsum, init, xs, None) + self.assertAllClose(carry, 40.0) + self.assertAllClose(result, ops.cumsum(xs)) + def test_scatter_basic_call(self): indices = np.array([[1, 0], [0, 1]]) values = np.array([10, 20]) From 6d658c03965463c702a74c5026095ac466e044a1 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 7 May 2024 16:02:00 +0800 
Subject: [PATCH 2/5] Fix lint --- keras/src/backend/numpy/core.py | 2 +- keras/src/backend/tensorflow/core.py | 3 +-- keras/src/backend/torch/core.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/keras/src/backend/numpy/core.py b/keras/src/backend/numpy/core.py index 5c2af12a134..06a60f5573e 100644 --- a/keras/src/backend/numpy/core.py +++ b/keras/src/backend/numpy/core.py @@ -141,7 +141,7 @@ def convert_numpy_to_keras_tensor(x): def scan(f, init, xs=None, length=None, reverse=False, unroll=1): - # Ref: https://github.com/google/jax/blob/jaxlib-v0.4.26/jax/_src/lax/control_flow/loops.py#L105 + # Ref: jax.lax.scan if not callable(f): raise TypeError(f"`f` should be a callable. Received: f={f}") if not isinstance(unroll, (bool, int)): diff --git a/keras/src/backend/tensorflow/core.py b/keras/src/backend/tensorflow/core.py index 7ca92525e8d..c5641eabdb5 100644 --- a/keras/src/backend/tensorflow/core.py +++ b/keras/src/backend/tensorflow/core.py @@ -212,8 +212,7 @@ def vectorized_map(function, elements): def scan(f, init, xs=None, length=None, reverse=False, unroll=1): # We have reimplemented `scan` to match the behavior of `jax.lax.scan` - # Ref: https://github.com/tensorflow/tensorflow/blob/v2.16.1/tensorflow/python/ops/functional_ops.py#L437 - # Ref: https://github.com/google/jax/blob/jaxlib-v0.4.26/jax/_src/lax/control_flow/loops.py#L105 + # Ref: tf.scan, jax.lax.scan if not callable(f): raise TypeError(f"`f` should be a callable. 
Received: f={f}") if not isinstance(unroll, (bool, int)): diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 7f443779b8b..f5d2bae126f 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -341,7 +341,7 @@ def vectorized_map(function, elements): def scan(f, init, xs=None, length=None, reverse=False, unroll=1): - # Ref: https://github.com/google/jax/blob/jaxlib-v0.4.26/jax/_src/lax/control_flow/loops.py#L105 + # Ref: jax.lax.scan if not callable(f): raise TypeError(f"`f` should be a callable. Received: f={f}") if not isinstance(unroll, (bool, int)): From b7688f4f9fa958f1d355f950228a316157b34dfa Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 7 May 2024 16:37:25 +0800 Subject: [PATCH 3/5] Increase test coverage --- keras/src/backend/jax/core.py | 6 +++++ keras/src/backend/numpy/core.py | 11 ++++---- keras/src/backend/tensorflow/core.py | 11 ++++---- keras/src/backend/torch/core.py | 11 ++++---- keras/src/ops/core_test.py | 38 +++++++++++++++++++++++----- 5 files changed, 55 insertions(+), 22 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 3f9ce3104fc..cab6fbd482d 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -254,6 +254,12 @@ def vectorized_map(function, elements): def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + if not isinstance(unroll, bool): + if not isinstance(unroll, int) or unroll < 1: + raise TypeError( + "`unroll` must be an positive integer or boolean. 
" + f"Received: unroll={unroll}" + ) return jax.lax.scan( f, init=init, xs=xs, length=length, reverse=reverse, unroll=unroll ) diff --git a/keras/src/backend/numpy/core.py b/keras/src/backend/numpy/core.py index 06a60f5573e..04efc3d7472 100644 --- a/keras/src/backend/numpy/core.py +++ b/keras/src/backend/numpy/core.py @@ -144,11 +144,12 @@ def scan(f, init, xs=None, length=None, reverse=False, unroll=1): # Ref: jax.lax.scan if not callable(f): raise TypeError(f"`f` should be a callable. Received: f={f}") - if not isinstance(unroll, (bool, int)): - raise TypeError( - "`unroll` must be an positive integer or boolean. " - f"Received: unroll={unroll}" - ) + if not isinstance(unroll, bool): + if not isinstance(unroll, int) or unroll < 1: + raise TypeError( + "`unroll` must be an positive integer or boolean. " + f"Received: unroll={unroll}" + ) if xs is None and length is None: raise ValueError("Got no `xs` to scan over and `length` not provided.") diff --git a/keras/src/backend/tensorflow/core.py b/keras/src/backend/tensorflow/core.py index c5641eabdb5..57cbe657d85 100644 --- a/keras/src/backend/tensorflow/core.py +++ b/keras/src/backend/tensorflow/core.py @@ -215,11 +215,12 @@ def scan(f, init, xs=None, length=None, reverse=False, unroll=1): # Ref: tf.scan, jax.lax.scan if not callable(f): raise TypeError(f"`f` should be a callable. Received: f={f}") - if not isinstance(unroll, (bool, int)): - raise TypeError( - "`unroll` must be an positive integer or boolean. " - f"Received: unroll={unroll}" - ) + if not isinstance(unroll, bool): + if not isinstance(unroll, int) or unroll < 1: + raise TypeError( + "`unroll` must be an positive integer or boolean. 
" + f"Received: unroll={unroll}" + ) if xs is None and length is None: raise ValueError("Got no `xs` to scan over and `length` not provided.") diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index f5d2bae126f..829968d8cdd 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -344,11 +344,12 @@ def scan(f, init, xs=None, length=None, reverse=False, unroll=1): # Ref: jax.lax.scan if not callable(f): raise TypeError(f"`f` should be a callable. Received: f={f}") - if not isinstance(unroll, (bool, int)): - raise TypeError( - "`unroll` must be an positive integer or boolean. " - f"Received: unroll={unroll}" - ) + if not isinstance(unroll, bool): + if not isinstance(unroll, int) or unroll < 1: + raise TypeError( + "`unroll` must be an positive integer or boolean. " + f"Received: unroll={unroll}" + ) if xs is None and length is None: raise ValueError("Got no `xs` to scan over and `length` not provided.") diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index b7450097cf1..dae9727fe16 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -707,13 +707,6 @@ def test_convert_to_tensor(self, x, dtype, expected_dtype): expected_dtype, ) - def test_convert_to_numpy(self): - x = ops.array([1, 2, 3], dtype="float32") - y = ops.convert_to_numpy(x) - self.assertIsInstance(y, np.ndarray) - # Test assignment -- should not fail. - y[0] = 1.0 - class CoreOpsCallsTests(testing.TestCase): def test_scan_basic_call(self): @@ -970,3 +963,34 @@ def test_cond_check_output_spec_tuple(self): (mock_spec,), (mock_spec, mock_spec_different) ) ) + + +class CoreOpsBehaviorTests(testing.TestCase): + def test_convert_to_numpy(self): + x = ops.array([1, 2, 3], dtype="float32") + y = ops.convert_to_numpy(x) + self.assertIsInstance(y, np.ndarray) + # Test assignment -- should not fail. 
+ y[0] = 1.0 + + def test_scan_invalid_arguments(self): + def cumsum(carry, xs): + carry = carry + xs + return carry, carry + + init = np.array(0, dtype="float32") + xs = np.array([1, 2, 3, 4, 10, 20], dtype="float32") + + # Test non-callable + with self.assertRaisesRegex(TypeError, "should be a callable."): + core.scan(123, init, xs) + + # Test bad unroll + with self.assertRaisesRegex( + TypeError, "must be an positive integer or boolean." + ): + core.scan(cumsum, init, xs, unroll=-1) + + # Test both xs and length are None + with self.assertRaisesRegex(ValueError, "to scan over and"): + core.scan(cumsum, init, xs=None, length=None) From b99e0f20c985eae8c189f0b0d7ff77acbee289b1 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 7 May 2024 16:42:11 +0800 Subject: [PATCH 4/5] Increase test coverage --- keras/src/ops/core.py | 4 +++- keras/src/ops/core_test.py | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/keras/src/ops/core.py b/keras/src/ops/core.py index 557edee62d2..3b584129099 100644 --- a/keras/src/ops/core.py +++ b/keras/src/ops/core.py @@ -40,14 +40,16 @@ def call(self, f, init, xs, length): def compute_output_spec(self, f, init, xs, length): if xs is None: n = int(length) + x = None else: n = ( int(length) if length is not None else tree.flatten(xs)[0].shape[0] ) + x = xs[0] - carry_spec, y_spec = backend.compute_output_spec(f, init, xs[0]) + carry_spec, y_spec = backend.compute_output_spec(f, init, x) y_spec.shape = (n,) + y_spec.shape return carry_spec, y_spec diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index dae9727fe16..0a461c4b2fc 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -30,6 +30,13 @@ def f(carry, xs): self.assertEqual(carry.shape, ()) self.assertEqual(result.shape, (6,)) + def f2(carry, _): + return carry, carry + + carry, result = core.scan(f2, init, xs=None, length=3) + self.assertEqual(carry.shape, ()) + 
self.assertEqual(result.shape, (3,)) + def test_scatter(self): indices = KerasTensor((5, 2)) values = KerasTensor((5,)) From e91bb56d076046573bf207a21eb9110a39580bcd Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 7 May 2024 16:49:55 +0800 Subject: [PATCH 5/5] Replace `TypeError` with `ValueError` for invalid `unroll` --- keras/src/backend/jax/core.py | 2 +- keras/src/backend/numpy/core.py | 2 +- keras/src/backend/tensorflow/core.py | 2 +- keras/src/backend/torch/core.py | 2 +- keras/src/ops/core_test.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index cab6fbd482d..d4682ffff1c 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -256,7 +256,7 @@ def vectorized_map(function, elements): def scan(f, init, xs=None, length=None, reverse=False, unroll=1): if not isinstance(unroll, bool): if not isinstance(unroll, int) or unroll < 1: - raise TypeError( + raise ValueError( "`unroll` must be an positive integer or boolean. " f"Received: unroll={unroll}" ) diff --git a/keras/src/backend/numpy/core.py b/keras/src/backend/numpy/core.py index 04efc3d7472..25fb1ea0718 100644 --- a/keras/src/backend/numpy/core.py +++ b/keras/src/backend/numpy/core.py @@ -146,7 +146,7 @@ def scan(f, init, xs=None, length=None, reverse=False, unroll=1): raise TypeError(f"`f` should be a callable. Received: f={f}") if not isinstance(unroll, bool): if not isinstance(unroll, int) or unroll < 1: - raise TypeError( + raise ValueError( "`unroll` must be an positive integer or boolean. 
" f"Received: unroll={unroll}" ) diff --git a/keras/src/backend/tensorflow/core.py b/keras/src/backend/tensorflow/core.py index 57cbe657d85..1f1c72a1e9f 100644 --- a/keras/src/backend/tensorflow/core.py +++ b/keras/src/backend/tensorflow/core.py @@ -217,7 +217,7 @@ def scan(f, init, xs=None, length=None, reverse=False, unroll=1): raise TypeError(f"`f` should be a callable. Received: f={f}") if not isinstance(unroll, bool): if not isinstance(unroll, int) or unroll < 1: - raise TypeError( + raise ValueError( "`unroll` must be an positive integer or boolean. " f"Received: unroll={unroll}" ) diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 829968d8cdd..3fb14a75b3e 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -346,7 +346,7 @@ def scan(f, init, xs=None, length=None, reverse=False, unroll=1): raise TypeError(f"`f` should be a callable. Received: f={f}") if not isinstance(unroll, bool): if not isinstance(unroll, int) or unroll < 1: - raise TypeError( + raise ValueError( "`unroll` must be an positive integer or boolean. " f"Received: unroll={unroll}" ) diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index 0a461c4b2fc..5b731230c08 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -994,7 +994,7 @@ def cumsum(carry, xs): # Test bad unroll with self.assertRaisesRegex( - TypeError, "must be an positive integer or boolean." + ValueError, "must be an positive integer or boolean." ): core.scan(cumsum, init, xs, unroll=-1)