diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index be8f00acb55..da6d46b0c72 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -15,6 +15,7 @@ from keras.src.ops.core import custom_gradient from keras.src.ops.core import fori_loop from keras.src.ops.core import is_tensor +from keras.src.ops.core import scan from keras.src.ops.core import scatter from keras.src.ops.core import scatter_update from keras.src.ops.core import shape diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index be8f00acb55..da6d46b0c72 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -15,6 +15,7 @@ from keras.src.ops.core import custom_gradient from keras.src.ops.core import fori_loop from keras.src.ops.core import is_tensor +from keras.src.ops.core import scan from keras.src.ops.core import scatter from keras.src.ops.core import scatter_update from keras.src.ops.core import shape diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index e255afcfe49..d4682ffff1c 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -253,6 +253,18 @@ def vectorized_map(function, elements): return jax.vmap(function)(elements) +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + if not isinstance(unroll, bool): + if not isinstance(unroll, int) or unroll < 1: + raise ValueError( + "`unroll` must be an positive integer or boolean. " + f"Received: unroll={unroll}" + ) + return jax.lax.scan( + f, init=init, xs=xs, length=length, reverse=reverse, unroll=unroll + ) + + def scatter(indices, values, shape): zeros = jnp.zeros(shape, values.dtype) key = tuple(jnp.moveaxis(indices, -1, 0)) diff --git a/keras/src/backend/numpy/core.py b/keras/src/backend/numpy/core.py index c00b2598dfd..25fb1ea0718 100644 --- a/keras/src/backend/numpy/core.py +++ b/keras/src/backend/numpy/core.py @@ -140,6 +140,55 @@ def convert_numpy_to_keras_tensor(x): return output_spec +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + # Ref: jax.lax.scan + if not callable(f): + raise TypeError(f"`f` should be a callable. Received: f={f}") + if not isinstance(unroll, bool): + if not isinstance(unroll, int) or unroll < 1: + raise ValueError( + "`unroll` must be an positive integer or boolean. " + f"Received: unroll={unroll}" + ) + if xs is None and length is None: + raise ValueError("Got no `xs` to scan over and `length` not provided.") + + input_is_sequence = tree.is_nested(xs) + output_is_sequence = tree.is_nested(init) + + def pack_input(x): + return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0] + + def pack_output(x): + return tree.pack_sequence_as(init, x) if output_is_sequence else x[0] + + if xs is None: + xs_flat = [] + n = int(length) + else: + xs_flat = tree.flatten(xs) + xs_flat = [convert_to_tensor(elem) for elem in xs_flat] + n = int(length) if length is not None else shape(xs_flat[0])[0] + + init_flat = tree.flatten(init) + init_flat = [convert_to_tensor(init) for init in init_flat] + init = pack_output(init_flat) + dummy_y = [np.zeros_like(init) for init in init_flat] + + carry = init + ys = [] + maybe_reversed = reversed if reverse else lambda x: x + for i in maybe_reversed(range(n)): + xs_slice = [x[i] for x in xs_flat] + packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None + carry, y = f(carry, packed_xs) + ys.append(y if y is not None else dummy_y) + stacked_y = tree.map_structure( + lambda *ys: np.stack(ys), *maybe_reversed(ys) + ) + return carry, stacked_y + + def scatter(indices, values, shape): indices = convert_to_tensor(indices) values = convert_to_tensor(values) diff --git a/keras/src/backend/tensorflow/core.py b/keras/src/backend/tensorflow/core.py index 89629215f0a..1f1c72a1e9f 100644 --- a/keras/src/backend/tensorflow/core.py +++ b/keras/src/backend/tensorflow/core.py @@ -210,6 +210,127 @@ def vectorized_map(function, elements): return tf.vectorized_map(function, elements) +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + # We have reimplemented `scan` to match the behavior of `jax.lax.scan` + # Ref: tf.scan, jax.lax.scan + if not callable(f): + raise TypeError(f"`f` should be a callable. Received: f={f}") + if not isinstance(unroll, bool): + if not isinstance(unroll, int) or unroll < 1: + raise ValueError( + "`unroll` must be an positive integer or boolean. " + f"Received: unroll={unroll}" + ) + if xs is None and length is None: + raise ValueError("Got no `xs` to scan over and `length` not provided.") + + input_is_sequence = tree.is_nested(xs) + output_is_sequence = tree.is_nested(init) + + def pack_input(x): + return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0] + + def pack_output(x): + return tree.pack_sequence_as(init, x) if output_is_sequence else x[0] + + if xs is None: + xs_flat = [] + n = int(length) + else: + # xs_flat = flatten_input(xs) + xs_flat = tree.flatten(xs) + xs_flat = [tf.convert_to_tensor(elem) for elem in xs_flat] + n = int(length) if length is not None else tf.shape(xs_flat[0])[0] + + # TensorArrays are always flat + xs_array = [ + tf.TensorArray( + dtype=x.dtype, + size=n, + dynamic_size=False, + element_shape=x.shape[1:], + infer_shape=True, + ) + for x in xs_flat + ] + xs_array = [x_a.unstack(x) for x_a, x in zip(xs_array, xs_flat)] + + init_flat = tree.flatten(init) + carry_flat = [tf.convert_to_tensor(init) for init in init_flat] + + # Store the intermediate values + # Note: there is a constraint that the output of `f` must have the same + # shape and dtype as carry (`init`). + ys_array = [ + tf.TensorArray( + dtype=carry.dtype, + size=n, + dynamic_size=False, + element_shape=carry.shape, + infer_shape=True, + ) + for carry in carry_flat + ] + carry_array = [ + tf.TensorArray( + dtype=carry.dtype, + size=1, + dynamic_size=False, + clear_after_read=False, + element_shape=carry.shape, + infer_shape=True, + ) + for carry in carry_flat + ] + carry_array = [ + carry.write(0, c) for (carry, c) in zip(carry_array, carry_flat) + ] + + def loop_body(i, carry_array, ys_array): + packed_xs = ( + pack_input([xs.read(i) for xs in xs_array]) + if len(xs_array) > 0 + else None + ) + packed_carry = pack_output([carry.read(0) for carry in carry_array]) + + carry, ys = f(packed_carry, packed_xs) + + if ys is not None: + flat_ys = tree.flatten(ys) + ys_array = [ys.write(i, v) for (ys, v) in zip(ys_array, flat_ys)] + if carry is not None: + flat_carry = tree.flatten(carry) + carry_array = [ + carry.write(0, v) for (carry, v) in zip(carry_array, flat_carry) + ] + next_i = i + 1 if not reverse else i - 1 + return (next_i, carry_array, ys_array) + + if isinstance(unroll, bool): + unroll = max(n, 1) if unroll else 1 + + _, carry_array, ys_array = tf.while_loop( + lambda i, _1, _2: i >= 0 if reverse else i < n, + loop_body, + (n - 1 if reverse else 0, carry_array, ys_array), + parallel_iterations=unroll, + ) + + ys_flat = [ys.stack() for ys in ys_array] + carry_flat = [carry.read(0) for carry in carry_array] + if xs is not None: + n_static = xs_flat[0].get_shape().with_rank_at_least(1)[0] + if not isinstance(n_static, int): + for x in xs_flat[1:]: + n_static.assert_is_compatible_with( + x.get_shape().with_rank_at_least(1)[0] + ) + for r in ys_flat: + r.set_shape(tf.TensorShape(n_static).concatenate(r.get_shape()[1:])) + return pack_output(carry_flat), pack_output(ys_flat) + + def scatter(indices, values, shape): return tf.scatter_nd(indices, values, shape) diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 68453255b1f..3fb14a75b3e 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -340,6 +340,55 @@ def vectorized_map(function, elements): return torch.vmap(function)(elements) +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + # Ref: jax.lax.scan + if not callable(f): + raise TypeError(f"`f` should be a callable. Received: f={f}") + if not isinstance(unroll, bool): + if not isinstance(unroll, int) or unroll < 1: + raise ValueError( + "`unroll` must be an positive integer or boolean. " + f"Received: unroll={unroll}" + ) + if xs is None and length is None: + raise ValueError("Got no `xs` to scan over and `length` not provided.") + + input_is_sequence = tree.is_nested(xs) + output_is_sequence = tree.is_nested(init) + + def pack_input(x): + return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0] + + def pack_output(x): + return tree.pack_sequence_as(init, x) if output_is_sequence else x[0] + + if xs is None: + xs_flat = [] + n = int(length) + else: + xs_flat = tree.flatten(xs) + xs_flat = [convert_to_tensor(elem) for elem in xs_flat] + n = int(length) if length is not None else shape(xs_flat[0])[0] + + init_flat = tree.flatten(init) + init_flat = [convert_to_tensor(init) for init in init_flat] + init = pack_output(init_flat) + dummy_y = [torch.zeros_like(init) for init in init_flat] + + carry = init + ys = [] + maybe_reversed = reversed if reverse else lambda x: x + for i in maybe_reversed(range(n)): + xs_slice = [x[i] for x in xs_flat] + packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None + carry, y = f(carry, packed_xs) + ys.append(y if y is not None else dummy_y) + stacked_y = tree.map_structure( + lambda *ys: torch.stack(ys), *maybe_reversed(ys) + ) + return carry, stacked_y + + def scatter(indices, values, shape): indices = convert_to_tensor(indices) values = convert_to_tensor(values) diff --git a/keras/src/ops/core.py b/keras/src/ops/core.py index 5f581fa8678..3b584129099 100644 --- a/keras/src/ops/core.py +++ b/keras/src/ops/core.py @@ -1,4 +1,5 @@ """ +scan scatter scatter_update slice @@ -25,6 +26,113 @@ from keras.src.utils import traceback_utils +class Scan(Operation): + def __init__(self, reverse=False, unroll=1): + super().__init__() + self.reverse = reverse + self.unroll = unroll + + def call(self, f, init, xs, length): + return backend.core.scan( + f, init, xs, length, reverse=self.reverse, unroll=self.unroll + ) + + def compute_output_spec(self, f, init, xs, length): + if xs is None: + n = int(length) + x = None + else: + n = ( + int(length) + if length is not None + else tree.flatten(xs)[0].shape[0] + ) + x = xs[0] + + carry_spec, y_spec = backend.compute_output_spec(f, init, x) + y_spec.shape = (n,) + y_spec.shape + return carry_spec, y_spec + + +@keras_export("keras.ops.scan") +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + """Scan a function over leading array axes while carrying along state. + + When the type of `xs` is an array type or `None`, and the type of `ys` is an + array type, the semantics of `scan()` are given roughly by this Python + implementation: + + ```python + def scan(f, init, xs, length=None): + if xs is None: + xs = [None] * length + carry = init + ys = [] + for x in xs: + carry, y = f(carry, x) + ys.append(y) + return carry, np.stack(ys) + ``` + + The loop-carried value `carry` (`init`) must hold a fixed shape and dtype + across all iterations. + + In TensorFlow, `y` must match `carry` in shape and dtype. This is not + required in other backends. + + Args: + f: Callable defines the logic for each loop iteration. This accepts two + arguments where the first is a value of the loop carry and the + second is a slice of `xs` along its leading axis. + This callable returns a pair where the first represents a new value + for the loop carry and the second represents a slice of the output. + init: The initial loop carry value. This can be a scalar, tensor, or any + nested structure. It must match the structure of the first element + returned by `f`. + xs: Optional value to scan along its leading axis. This can be a tensor + or any nested structure. If `xs` is not provided, you must specify + `length` to define the number of loop iterations. + Defaults to `None`. + length: Optional integer specifying the number of loop iterations. + If `length` is not provided, it defaults to the sizes of leading + axis of the arrays in `xs`. Defaults to `None`. + reverse: Optional boolean specifying whether to run the scan iteration + forward or in reverse, equivalent to reversing the leading axes of + the arrays in both `xs` and in `ys`. + unroll: Optional positive integer or boolean specifying how many scan + iterations to unroll within a single iteration of a loop. If an + integer is provided, it determines how many unrolled loop iterations + to run within a single rolled iteration of the loop. If a boolean is + provided, it will determine if the loop is completely unrolled + (`unroll=True`) or left completely unrolled (`unroll=False`). + Note that unrolling is only supported by JAX and TensorFlow + backends. + + Returns: + A pair where the first element represents the final loop carry value and + the second element represents the stacked outputs of `f` when scanned + over the leading axis of the inputs. + + Examples: + + >>> sum_fn = lambda c, x: (c + x, c + x) + >>> init = keras.ops.array(0) + >>> xs = keras.ops.array([1, 2, 3, 4, 5]) + >>> carry, result = ops.scan(sum_fn, init, xs) + >>> carry + 15 + >>> result + [1, 3, 6, 10, 15] + """ + if any_symbolic_tensors((init, xs)): + return Scan(reverse=reverse, unroll=unroll).symbolic_call( + f, init, xs, length + ) + return backend.core.scan( + f, init, xs, length, reverse=reverse, unroll=unroll + ) + + class Scatter(Operation): def call(self, indices, values, shape): return backend.core.scatter(indices, values, shape) diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index 3993b81637b..5b731230c08 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -19,6 +19,24 @@ class CoreOpsStaticShapeTest(testing.TestCase): + def test_scan(self): + def f(carry, xs): + xs = xs + carry + return carry, carry + + init = KerasTensor(()) + xs = KerasTensor((6,)) + carry, result = core.scan(f, init, xs) + self.assertEqual(carry.shape, ()) + self.assertEqual(result.shape, (6,)) + + def f2(carry, _): + return carry, carry + + carry, result = core.scan(f2, init, xs=None, length=3) + self.assertEqual(carry.shape, ()) + self.assertEqual(result.shape, (3,)) + def test_scatter(self): indices = KerasTensor((5, 2)) values = KerasTensor((5,)) @@ -85,6 +103,69 @@ def test_unstack(self): class CoreOpsCorrectnessTest(testing.TestCase, parameterized.TestCase): + def test_scan(self): + # Test cumsum + def cumsum(carry, xs): + carry = carry + xs + return carry, carry + + init = np.array(0, dtype="float32") + xs = np.array([1, 2, 3, 4, 10, 20], dtype="float32") + carry, result = core.scan(cumsum, init, xs) + self.assertAllClose(carry, 40.0) + self.assertAllClose(result, ops.cumsum(xs)) + + # Test reverse=True + carry, result = core.scan(cumsum, init, xs, reverse=True) + self.assertAllClose(carry, 40.0) + self.assertAllClose(result, [40, 39, 37, 34, 30, 20]) + + # Test unroll + for unroll in (True, False, 2): + carry, result = core.scan(cumsum, init, xs, unroll=unroll) + self.assertAllClose(carry, 40.0) + self.assertAllClose(result, ops.cumsum(xs)) + + # Test xs is None + def fibonaccis(carry, _): + return (carry[1], carry[0] + carry[1]), None + + init = (np.array(0, dtype="float32"), np.array(1, dtype="float32")) + carry, _ = core.scan(fibonaccis, init, length=6) + self.assertAllClose(carry, [8, 13]) + + # Test nested init + if backend.backend() != "tensorflow": + # tensorflow doesn't support arbitrary shape/dtype of the output of + # `f`. It must be the same as `init`. + def multiply_two(carry, _): + value1 = carry["value1"] + value2 = carry["value2"] + return ( + {"value1": value1 * 2, "value2": value2 * 2}, + value1 * 2 + value2 * 2, + ) + + init = {"value1": 2.0, "value2": 3.0} + carry, result = core.scan(multiply_two, init, length=3) + self.assertAllClose(carry["value1"], 16) + self.assertAllClose(carry["value2"], 24) + self.assertAllClose(result, [10, 20, 40]) + + # Test nested xs + def reduce_add(carry, xs): + value1 = xs["value1"] + value2 = xs["value2"] + return carry, value1 + value2 + + init = np.array(0, dtype="float32") + xs = { + "value1": np.array([1, 2, 3], dtype="float32"), + "value2": np.array([10, 20, 30], dtype="float32"), + } + _, result = core.scan(reduce_add, init, xs) + self.assertAllClose(result, [11, 22, 33]) + def test_scatter(self): # Test 1D indices = np.array([[1], [3], [4], [7]]) @@ -633,15 +714,20 @@ def test_convert_to_tensor(self, x, dtype, expected_dtype): expected_dtype, ) - def test_convert_to_numpy(self): - x = ops.array([1, 2, 3], dtype="float32") - y = ops.convert_to_numpy(x) - self.assertIsInstance(y, np.ndarray) - # Test assignment -- should not fail. - y[0] = 1.0 - class CoreOpsCallsTests(testing.TestCase): + def test_scan_basic_call(self): + def cumsum(carry, xs): + carry = carry + xs + return carry, carry + + init = np.array(0, dtype="float32") + xs = np.array([1, 2, 3, 4, 10, 20], dtype="float32") + scan_op = core.Scan() + carry, result = scan_op.call(cumsum, init, xs, None) + self.assertAllClose(carry, 40.0) + self.assertAllClose(result, ops.cumsum(xs)) + def test_scatter_basic_call(self): indices = np.array([[1, 0], [0, 1]]) values = np.array([10, 20]) @@ -884,3 +970,34 @@ def test_cond_check_output_spec_tuple(self): (mock_spec,), (mock_spec, mock_spec_different) ) ) + + +class CoreOpsBehaviorTests(testing.TestCase): + def test_convert_to_numpy(self): + x = ops.array([1, 2, 3], dtype="float32") + y = ops.convert_to_numpy(x) + self.assertIsInstance(y, np.ndarray) + # Test assignment -- should not fail. + y[0] = 1.0 + + def test_scan_invalid_arguments(self): + def cumsum(carry, xs): + carry = carry + xs + return carry, carry + + init = np.array(0, dtype="float32") + xs = np.array([1, 2, 3, 4, 10, 20], dtype="float32") + + # Test non-callable + with self.assertRaisesRegex(TypeError, "should be a callable."): + core.scan(123, init, xs) + + # Test bad unroll + with self.assertRaisesRegex( + ValueError, "must be an positive integer or boolean." + ): + core.scan(cumsum, init, xs, unroll=-1) + + # Test both xs and length are None + with self.assertRaisesRegex(ValueError, "to scan over and"): + core.scan(cumsum, init, xs=None, length=None)