Skip to content

Commit

Permalink
Add scan op (#19681)
Browse files Browse the repository at this point in the history
* Add `scan`

* Fix lint

* Increase test coverage

* Increase test coverage

* Replace `TypeError` with `ValueError` for invalid `unroll`
  • Loading branch information
james77777778 authored May 8, 2024
1 parent 10c27c0 commit 43e5155
Show file tree
Hide file tree
Showing 8 changed files with 465 additions and 7 deletions.
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from keras.src.ops.core import custom_gradient
from keras.src.ops.core import fori_loop
from keras.src.ops.core import is_tensor
from keras.src.ops.core import scan
from keras.src.ops.core import scatter
from keras.src.ops.core import scatter_update
from keras.src.ops.core import shape
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from keras.src.ops.core import custom_gradient
from keras.src.ops.core import fori_loop
from keras.src.ops.core import is_tensor
from keras.src.ops.core import scan
from keras.src.ops.core import scatter
from keras.src.ops.core import scatter_update
from keras.src.ops.core import shape
Expand Down
12 changes: 12 additions & 0 deletions keras/src/backend/jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,18 @@ def vectorized_map(function, elements):
return jax.vmap(function)(elements)


def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
if not isinstance(unroll, bool):
if not isinstance(unroll, int) or unroll < 1:
raise ValueError(
"`unroll` must be an positive integer or boolean. "
f"Received: unroll={unroll}"
)
return jax.lax.scan(
f, init=init, xs=xs, length=length, reverse=reverse, unroll=unroll
)


def scatter(indices, values, shape):
zeros = jnp.zeros(shape, values.dtype)
key = tuple(jnp.moveaxis(indices, -1, 0))
Expand Down
49 changes: 49 additions & 0 deletions keras/src/backend/numpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,55 @@ def convert_numpy_to_keras_tensor(x):
return output_spec


def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
# Ref: jax.lax.scan
if not callable(f):
raise TypeError(f"`f` should be a callable. Received: f={f}")
if not isinstance(unroll, bool):
if not isinstance(unroll, int) or unroll < 1:
raise ValueError(
"`unroll` must be an positive integer or boolean. "
f"Received: unroll={unroll}"
)
if xs is None and length is None:
raise ValueError("Got no `xs` to scan over and `length` not provided.")

input_is_sequence = tree.is_nested(xs)
output_is_sequence = tree.is_nested(init)

def pack_input(x):
return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0]

def pack_output(x):
return tree.pack_sequence_as(init, x) if output_is_sequence else x[0]

if xs is None:
xs_flat = []
n = int(length)
else:
xs_flat = tree.flatten(xs)
xs_flat = [convert_to_tensor(elem) for elem in xs_flat]
n = int(length) if length is not None else shape(xs_flat[0])[0]

init_flat = tree.flatten(init)
init_flat = [convert_to_tensor(init) for init in init_flat]
init = pack_output(init_flat)
dummy_y = [np.zeros_like(init) for init in init_flat]

carry = init
ys = []
maybe_reversed = reversed if reverse else lambda x: x
for i in maybe_reversed(range(n)):
xs_slice = [x[i] for x in xs_flat]
packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None
carry, y = f(carry, packed_xs)
ys.append(y if y is not None else dummy_y)
stacked_y = tree.map_structure(
lambda *ys: np.stack(ys), *maybe_reversed(ys)
)
return carry, stacked_y


def scatter(indices, values, shape):
indices = convert_to_tensor(indices)
values = convert_to_tensor(values)
Expand Down
121 changes: 121 additions & 0 deletions keras/src/backend/tensorflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,127 @@ def vectorized_map(function, elements):
return tf.vectorized_map(function, elements)


def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
# We have reimplemented `scan` to match the behavior of `jax.lax.scan`
# Ref: tf.scan, jax.lax.scan
if not callable(f):
raise TypeError(f"`f` should be a callable. Received: f={f}")
if not isinstance(unroll, bool):
if not isinstance(unroll, int) or unroll < 1:
raise ValueError(
"`unroll` must be an positive integer or boolean. "
f"Received: unroll={unroll}"
)
if xs is None and length is None:
raise ValueError("Got no `xs` to scan over and `length` not provided.")

input_is_sequence = tree.is_nested(xs)
output_is_sequence = tree.is_nested(init)

def pack_input(x):
return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0]

def pack_output(x):
return tree.pack_sequence_as(init, x) if output_is_sequence else x[0]

if xs is None:
xs_flat = []
n = int(length)
else:
# xs_flat = flatten_input(xs)
xs_flat = tree.flatten(xs)
xs_flat = [tf.convert_to_tensor(elem) for elem in xs_flat]
n = int(length) if length is not None else tf.shape(xs_flat[0])[0]

# TensorArrays are always flat
xs_array = [
tf.TensorArray(
dtype=x.dtype,
size=n,
dynamic_size=False,
element_shape=x.shape[1:],
infer_shape=True,
)
for x in xs_flat
]
xs_array = [x_a.unstack(x) for x_a, x in zip(xs_array, xs_flat)]

init_flat = tree.flatten(init)
carry_flat = [tf.convert_to_tensor(init) for init in init_flat]

# Store the intermediate values
# Note: there is a constraint that the output of `f` must have the same
# shape and dtype as carry (`init`).
ys_array = [
tf.TensorArray(
dtype=carry.dtype,
size=n,
dynamic_size=False,
element_shape=carry.shape,
infer_shape=True,
)
for carry in carry_flat
]
carry_array = [
tf.TensorArray(
dtype=carry.dtype,
size=1,
dynamic_size=False,
clear_after_read=False,
element_shape=carry.shape,
infer_shape=True,
)
for carry in carry_flat
]
carry_array = [
carry.write(0, c) for (carry, c) in zip(carry_array, carry_flat)
]

def loop_body(i, carry_array, ys_array):
packed_xs = (
pack_input([xs.read(i) for xs in xs_array])
if len(xs_array) > 0
else None
)
packed_carry = pack_output([carry.read(0) for carry in carry_array])

carry, ys = f(packed_carry, packed_xs)

if ys is not None:
flat_ys = tree.flatten(ys)
ys_array = [ys.write(i, v) for (ys, v) in zip(ys_array, flat_ys)]
if carry is not None:
flat_carry = tree.flatten(carry)
carry_array = [
carry.write(0, v) for (carry, v) in zip(carry_array, flat_carry)
]
next_i = i + 1 if not reverse else i - 1
return (next_i, carry_array, ys_array)

if isinstance(unroll, bool):
unroll = max(n, 1) if unroll else 1

_, carry_array, ys_array = tf.while_loop(
lambda i, _1, _2: i >= 0 if reverse else i < n,
loop_body,
(n - 1 if reverse else 0, carry_array, ys_array),
parallel_iterations=unroll,
)

ys_flat = [ys.stack() for ys in ys_array]
carry_flat = [carry.read(0) for carry in carry_array]
if xs is not None:
n_static = xs_flat[0].get_shape().with_rank_at_least(1)[0]
if not isinstance(n_static, int):
for x in xs_flat[1:]:
n_static.assert_is_compatible_with(
x.get_shape().with_rank_at_least(1)[0]
)
for r in ys_flat:
r.set_shape(tf.TensorShape(n_static).concatenate(r.get_shape()[1:]))
return pack_output(carry_flat), pack_output(ys_flat)


def scatter(indices, values, shape):
return tf.scatter_nd(indices, values, shape)

Expand Down
49 changes: 49 additions & 0 deletions keras/src/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,55 @@ def vectorized_map(function, elements):
return torch.vmap(function)(elements)


def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
# Ref: jax.lax.scan
if not callable(f):
raise TypeError(f"`f` should be a callable. Received: f={f}")
if not isinstance(unroll, bool):
if not isinstance(unroll, int) or unroll < 1:
raise ValueError(
"`unroll` must be an positive integer or boolean. "
f"Received: unroll={unroll}"
)
if xs is None and length is None:
raise ValueError("Got no `xs` to scan over and `length` not provided.")

input_is_sequence = tree.is_nested(xs)
output_is_sequence = tree.is_nested(init)

def pack_input(x):
return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0]

def pack_output(x):
return tree.pack_sequence_as(init, x) if output_is_sequence else x[0]

if xs is None:
xs_flat = []
n = int(length)
else:
xs_flat = tree.flatten(xs)
xs_flat = [convert_to_tensor(elem) for elem in xs_flat]
n = int(length) if length is not None else shape(xs_flat[0])[0]

init_flat = tree.flatten(init)
init_flat = [convert_to_tensor(init) for init in init_flat]
init = pack_output(init_flat)
dummy_y = [torch.zeros_like(init) for init in init_flat]

carry = init
ys = []
maybe_reversed = reversed if reverse else lambda x: x
for i in maybe_reversed(range(n)):
xs_slice = [x[i] for x in xs_flat]
packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None
carry, y = f(carry, packed_xs)
ys.append(y if y is not None else dummy_y)
stacked_y = tree.map_structure(
lambda *ys: torch.stack(ys), *maybe_reversed(ys)
)
return carry, stacked_y


def scatter(indices, values, shape):
indices = convert_to_tensor(indices)
values = convert_to_tensor(values)
Expand Down
108 changes: 108 additions & 0 deletions keras/src/ops/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""
scan
scatter
scatter_update
slice
Expand All @@ -25,6 +26,113 @@
from keras.src.utils import traceback_utils


class Scan(Operation):
def __init__(self, reverse=False, unroll=1):
super().__init__()
self.reverse = reverse
self.unroll = unroll

def call(self, f, init, xs, length):
return backend.core.scan(
f, init, xs, length, reverse=self.reverse, unroll=self.unroll
)

def compute_output_spec(self, f, init, xs, length):
if xs is None:
n = int(length)
x = None
else:
n = (
int(length)
if length is not None
else tree.flatten(xs)[0].shape[0]
)
x = xs[0]

carry_spec, y_spec = backend.compute_output_spec(f, init, x)
y_spec.shape = (n,) + y_spec.shape
return carry_spec, y_spec


@keras_export("keras.ops.scan")
def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
"""Scan a function over leading array axes while carrying along state.
When the type of `xs` is an array type or `None`, and the type of `ys` is an
array type, the semantics of `scan()` are given roughly by this Python
implementation:
```python
def scan(f, init, xs, length=None):
if xs is None:
xs = [None] * length
carry = init
ys = []
for x in xs:
carry, y = f(carry, x)
ys.append(y)
return carry, np.stack(ys)
```
The loop-carried value `carry` (`init`) must hold a fixed shape and dtype
across all iterations.
In TensorFlow, `y` must match `carry` in shape and dtype. This is not
required in other backends.
Args:
f: Callable defines the logic for each loop iteration. This accepts two
arguments where the first is a value of the loop carry and the
second is a slice of `xs` along its leading axis.
This callable returns a pair where the first represents a new value
for the loop carry and the second represents a slice of the output.
init: The initial loop carry value. This can be a scalar, tensor, or any
nested structure. It must match the structure of the first element
returned by `f`.
xs: Optional value to scan along its leading axis. This can be a tensor
or any nested structure. If `xs` is not provided, you must specify
`length` to define the number of loop iterations.
Defaults to `None`.
length: Optional integer specifying the number of loop iterations.
If `length` is not provided, it defaults to the sizes of leading
axis of the arrays in `xs`. Defaults to `None`.
reverse: Optional boolean specifying whether to run the scan iteration
forward or in reverse, equivalent to reversing the leading axes of
the arrays in both `xs` and in `ys`.
unroll: Optional positive integer or boolean specifying how many scan
iterations to unroll within a single iteration of a loop. If an
integer is provided, it determines how many unrolled loop iterations
to run within a single rolled iteration of the loop. If a boolean is
provided, it will determine if the loop is completely unrolled
(`unroll=True`) or left completely unrolled (`unroll=False`).
Note that unrolling is only supported by JAX and TensorFlow
backends.
Returns:
A pair where the first element represents the final loop carry value and
the second element represents the stacked outputs of `f` when scanned
over the leading axis of the inputs.
Examples:
>>> sum_fn = lambda c, x: (c + x, c + x)
>>> init = keras.ops.array(0)
>>> xs = keras.ops.array([1, 2, 3, 4, 5])
>>> carry, result = ops.scan(sum_fn, init, xs)
>>> carry
15
>>> result
[1, 3, 6, 10, 15]
"""
if any_symbolic_tensors((init, xs)):
return Scan(reverse=reverse, unroll=unroll).symbolic_call(
f, init, xs, length
)
return backend.core.scan(
f, init, xs, length, reverse=reverse, unroll=unroll
)


class Scatter(Operation):
def call(self, indices, values, shape):
return backend.core.scatter(indices, values, shape)
Expand Down
Loading

0 comments on commit 43e5155

Please sign in to comment.