Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scan op #19681

Merged
merged 7 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from keras.src.ops.core import custom_gradient
from keras.src.ops.core import fori_loop
from keras.src.ops.core import is_tensor
from keras.src.ops.core import scan
from keras.src.ops.core import scatter
from keras.src.ops.core import scatter_update
from keras.src.ops.core import shape
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from keras.src.ops.core import custom_gradient
from keras.src.ops.core import fori_loop
from keras.src.ops.core import is_tensor
from keras.src.ops.core import scan
from keras.src.ops.core import scatter
from keras.src.ops.core import scatter_update
from keras.src.ops.core import shape
Expand Down
12 changes: 12 additions & 0 deletions keras/src/backend/jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,18 @@ def vectorized_map(function, elements):
return jax.vmap(function)(elements)


def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
if not isinstance(unroll, bool):
if not isinstance(unroll, int) or unroll < 1:
raise ValueError(
"`unroll` must be an positive integer or boolean. "
f"Received: unroll={unroll}"
)
return jax.lax.scan(
f, init=init, xs=xs, length=length, reverse=reverse, unroll=unroll
)


def scatter(indices, values, shape):
zeros = jnp.zeros(shape, values.dtype)
key = tuple(jnp.moveaxis(indices, -1, 0))
Expand Down
49 changes: 49 additions & 0 deletions keras/src/backend/numpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,55 @@ def convert_numpy_to_keras_tensor(x):
return output_spec


def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
# Ref: jax.lax.scan
if not callable(f):
raise TypeError(f"`f` should be a callable. Received: f={f}")
if not isinstance(unroll, bool):
if not isinstance(unroll, int) or unroll < 1:
raise ValueError(
"`unroll` must be an positive integer or boolean. "
f"Received: unroll={unroll}"
)
if xs is None and length is None:
raise ValueError("Got no `xs` to scan over and `length` not provided.")

input_is_sequence = tree.is_nested(xs)
output_is_sequence = tree.is_nested(init)

def pack_input(x):
return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0]

def pack_output(x):
return tree.pack_sequence_as(init, x) if output_is_sequence else x[0]

if xs is None:
xs_flat = []
n = int(length)
else:
xs_flat = tree.flatten(xs)
xs_flat = [convert_to_tensor(elem) for elem in xs_flat]
n = int(length) if length is not None else shape(xs_flat[0])[0]

init_flat = tree.flatten(init)
init_flat = [convert_to_tensor(init) for init in init_flat]
init = pack_output(init_flat)
dummy_y = [np.zeros_like(init) for init in init_flat]

carry = init
ys = []
maybe_reversed = reversed if reverse else lambda x: x
for i in maybe_reversed(range(n)):
xs_slice = [x[i] for x in xs_flat]
packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None
carry, y = f(carry, packed_xs)
ys.append(y if y is not None else dummy_y)
stacked_y = tree.map_structure(
lambda *ys: np.stack(ys), *maybe_reversed(ys)
)
return carry, stacked_y


def scatter(indices, values, shape):
indices = convert_to_tensor(indices)
values = convert_to_tensor(values)
Expand Down
121 changes: 121 additions & 0 deletions keras/src/backend/tensorflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,127 @@ def vectorized_map(function, elements):
return tf.vectorized_map(function, elements)


def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
# We have reimplemented `scan` to match the behavior of `jax.lax.scan`
# Ref: tf.scan, jax.lax.scan
if not callable(f):
raise TypeError(f"`f` should be a callable. Received: f={f}")
if not isinstance(unroll, bool):
if not isinstance(unroll, int) or unroll < 1:
raise ValueError(
"`unroll` must be an positive integer or boolean. "
f"Received: unroll={unroll}"
)
if xs is None and length is None:
raise ValueError("Got no `xs` to scan over and `length` not provided.")

input_is_sequence = tree.is_nested(xs)
output_is_sequence = tree.is_nested(init)

def pack_input(x):
return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0]

def pack_output(x):
return tree.pack_sequence_as(init, x) if output_is_sequence else x[0]

if xs is None:
xs_flat = []
n = int(length)
else:
# xs_flat = flatten_input(xs)
xs_flat = tree.flatten(xs)
xs_flat = [tf.convert_to_tensor(elem) for elem in xs_flat]
n = int(length) if length is not None else tf.shape(xs_flat[0])[0]

# TensorArrays are always flat
xs_array = [
tf.TensorArray(
dtype=x.dtype,
size=n,
dynamic_size=False,
element_shape=x.shape[1:],
infer_shape=True,
)
for x in xs_flat
]
xs_array = [x_a.unstack(x) for x_a, x in zip(xs_array, xs_flat)]

init_flat = tree.flatten(init)
carry_flat = [tf.convert_to_tensor(init) for init in init_flat]

# Store the intermediate values
# Note: there is a constraint that the output of `f` must have the same
# shape and dtype as carry (`init`).
ys_array = [
tf.TensorArray(
dtype=carry.dtype,
size=n,
dynamic_size=False,
element_shape=carry.shape,
infer_shape=True,
)
for carry in carry_flat
]
carry_array = [
tf.TensorArray(
dtype=carry.dtype,
size=1,
dynamic_size=False,
clear_after_read=False,
element_shape=carry.shape,
infer_shape=True,
)
for carry in carry_flat
]
carry_array = [
carry.write(0, c) for (carry, c) in zip(carry_array, carry_flat)
]

def loop_body(i, carry_array, ys_array):
packed_xs = (
pack_input([xs.read(i) for xs in xs_array])
if len(xs_array) > 0
else None
)
packed_carry = pack_output([carry.read(0) for carry in carry_array])

carry, ys = f(packed_carry, packed_xs)

if ys is not None:
flat_ys = tree.flatten(ys)
ys_array = [ys.write(i, v) for (ys, v) in zip(ys_array, flat_ys)]
if carry is not None:
flat_carry = tree.flatten(carry)
carry_array = [
carry.write(0, v) for (carry, v) in zip(carry_array, flat_carry)
]
next_i = i + 1 if not reverse else i - 1
return (next_i, carry_array, ys_array)

if isinstance(unroll, bool):
unroll = max(n, 1) if unroll else 1

_, carry_array, ys_array = tf.while_loop(
lambda i, _1, _2: i >= 0 if reverse else i < n,
loop_body,
(n - 1 if reverse else 0, carry_array, ys_array),
parallel_iterations=unroll,
)

ys_flat = [ys.stack() for ys in ys_array]
carry_flat = [carry.read(0) for carry in carry_array]
if xs is not None:
n_static = xs_flat[0].get_shape().with_rank_at_least(1)[0]
if not isinstance(n_static, int):
for x in xs_flat[1:]:
n_static.assert_is_compatible_with(
x.get_shape().with_rank_at_least(1)[0]
)
for r in ys_flat:
r.set_shape(tf.TensorShape(n_static).concatenate(r.get_shape()[1:]))
return pack_output(carry_flat), pack_output(ys_flat)


def scatter(indices, values, shape):
return tf.scatter_nd(indices, values, shape)

Expand Down
49 changes: 49 additions & 0 deletions keras/src/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,55 @@ def vectorized_map(function, elements):
return torch.vmap(function)(elements)


def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
# Ref: jax.lax.scan
if not callable(f):
raise TypeError(f"`f` should be a callable. Received: f={f}")
if not isinstance(unroll, bool):
if not isinstance(unroll, int) or unroll < 1:
raise ValueError(
"`unroll` must be an positive integer or boolean. "
f"Received: unroll={unroll}"
)
if xs is None and length is None:
raise ValueError("Got no `xs` to scan over and `length` not provided.")

input_is_sequence = tree.is_nested(xs)
output_is_sequence = tree.is_nested(init)

def pack_input(x):
return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0]

def pack_output(x):
return tree.pack_sequence_as(init, x) if output_is_sequence else x[0]

if xs is None:
xs_flat = []
n = int(length)
else:
xs_flat = tree.flatten(xs)
xs_flat = [convert_to_tensor(elem) for elem in xs_flat]
n = int(length) if length is not None else shape(xs_flat[0])[0]

init_flat = tree.flatten(init)
init_flat = [convert_to_tensor(init) for init in init_flat]
init = pack_output(init_flat)
dummy_y = [torch.zeros_like(init) for init in init_flat]

carry = init
ys = []
maybe_reversed = reversed if reverse else lambda x: x
for i in maybe_reversed(range(n)):
xs_slice = [x[i] for x in xs_flat]
packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None
carry, y = f(carry, packed_xs)
ys.append(y if y is not None else dummy_y)
stacked_y = tree.map_structure(
lambda *ys: torch.stack(ys), *maybe_reversed(ys)
)
return carry, stacked_y


def scatter(indices, values, shape):
indices = convert_to_tensor(indices)
values = convert_to_tensor(values)
Expand Down
108 changes: 108 additions & 0 deletions keras/src/ops/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""
scan
scatter
scatter_update
slice
Expand All @@ -25,6 +26,113 @@
from keras.src.utils import traceback_utils


class Scan(Operation):
def __init__(self, reverse=False, unroll=1):
super().__init__()
self.reverse = reverse
self.unroll = unroll

def call(self, f, init, xs, length):
return backend.core.scan(
f, init, xs, length, reverse=self.reverse, unroll=self.unroll
)

def compute_output_spec(self, f, init, xs, length):
if xs is None:
n = int(length)
x = None
else:
n = (
int(length)
if length is not None
else tree.flatten(xs)[0].shape[0]
)
x = xs[0]

carry_spec, y_spec = backend.compute_output_spec(f, init, x)
y_spec.shape = (n,) + y_spec.shape
return carry_spec, y_spec


@keras_export("keras.ops.scan")
def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
"""Scan a function over leading array axes while carrying along state.

When the type of `xs` is an array type or `None`, and the type of `ys` is an
array type, the semantics of `scan()` are given roughly by this Python
implementation:

```python
def scan(f, init, xs, length=None):
if xs is None:
xs = [None] * length
carry = init
ys = []
for x in xs:
carry, y = f(carry, x)
ys.append(y)
return carry, np.stack(ys)
```

The loop-carried value `carry` (`init`) must hold a fixed shape and dtype
across all iterations.

In TensorFlow, `y` must match `carry` in shape and dtype. This is not
required in other backends.

Args:
f: Callable defines the logic for each loop iteration. This accepts two
arguments where the first is a value of the loop carry and the
second is a slice of `xs` along its leading axis.
This callable returns a pair where the first represents a new value
for the loop carry and the second represents a slice of the output.
init: The initial loop carry value. This can be a scalar, tensor, or any
nested structure. It must match the structure of the first element
returned by `f`.
xs: Optional value to scan along its leading axis. This can be a tensor
or any nested structure. If `xs` is not provided, you must specify
`length` to define the number of loop iterations.
Defaults to `None`.
length: Optional integer specifying the number of loop iterations.
If `length` is not provided, it defaults to the sizes of leading
axis of the arrays in `xs`. Defaults to `None`.
reverse: Optional boolean specifying whether to run the scan iteration
forward or in reverse, equivalent to reversing the leading axes of
the arrays in both `xs` and in `ys`.
unroll: Optional positive integer or boolean specifying how many scan
iterations to unroll within a single iteration of a loop. If an
integer is provided, it determines how many unrolled loop iterations
to run within a single rolled iteration of the loop. If a boolean is
provided, it will determine if the loop is completely unrolled
(`unroll=True`) or left completely unrolled (`unroll=False`).
Note that unrolling is only supported by JAX and TensorFlow
backends.

Returns:
A pair where the first element represents the final loop carry value and
the second element represents the stacked outputs of `f` when scanned
over the leading axis of the inputs.

Examples:

>>> sum_fn = lambda c, x: (c + x, c + x)
>>> init = keras.ops.array(0)
>>> xs = keras.ops.array([1, 2, 3, 4, 5])
>>> carry, result = ops.scan(sum_fn, init, xs)
>>> carry
15
>>> result
[1, 3, 6, 10, 15]
"""
if any_symbolic_tensors((init, xs)):
return Scan(reverse=reverse, unroll=unroll).symbolic_call(
f, init, xs, length
)
return backend.core.scan(
f, init, xs, length, reverse=reverse, unroll=unroll
)


class Scatter(Operation):
def call(self, indices, values, shape):
return backend.core.scatter(indices, values, shape)
Expand Down
Loading