Skip to content

Commit

Permalink
fixing segment_max edge cases and add more exception handling (keras-…
Browse files Browse the repository at this point in the history
  • Loading branch information
haohuanw authored Jul 9, 2024
1 parent 297df87 commit 5c8363b
Show file tree
Hide file tree
Showing 5 changed files with 322 additions and 233 deletions.
44 changes: 17 additions & 27 deletions keras/src/backend/numpy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from keras.src.utils.module_utils import scipy


def segment_sum(data, segment_ids, num_segments=None, sorted=False):
def _segment_reduction_fn(
data, segment_ids, reduction_method, num_segments, sorted
):
if num_segments is None:
num_segments = np.amax(segment_ids) + 1

Expand All @@ -21,45 +23,33 @@ def segment_sum(data, segment_ids, num_segments=None, sorted=False):
num_segments # Replace first dimension (which corresponds to segments)
)

if sorted:
if reduction_method == np.maximum:
result = np.ones(data_shape, dtype=valid_data.dtype) * -np.inf
else:
result = np.zeros(data_shape, dtype=valid_data.dtype)
np.add.at(result, valid_segment_ids, valid_data)

if sorted:
reduction_method.at(result, valid_segment_ids, valid_data)
else:
sort_indices = np.argsort(valid_segment_ids)
sorted_segment_ids = valid_segment_ids[sort_indices]
sorted_data = valid_data[sort_indices]

result = np.zeros(data_shape, dtype=valid_data.dtype)
np.add.at(result, sorted_segment_ids, sorted_data)
reduction_method.at(result, sorted_segment_ids, sorted_data)

return result


def segment_max(data, segment_ids, num_segments=None, sorted=False):
if num_segments is None:
num_segments = np.amax(segment_ids) + 1

valid_indices = segment_ids >= 0 # Ignore segment_ids that are -1
valid_data = data[valid_indices]
valid_segment_ids = segment_ids[valid_indices]

data_shape = list(valid_data.shape)
data_shape[0] = (
num_segments # Replace first dimension (which corresponds to segments)
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
return _segment_reduction_fn(
data, segment_ids, np.add, num_segments, sorted
)

if sorted:
result = np.zeros(data_shape, dtype=valid_data.dtype)
np.maximum.at(result, valid_segment_ids, valid_data)
else:
sort_indices = np.argsort(valid_segment_ids)
sorted_segment_ids = valid_segment_ids[sort_indices]
sorted_data = valid_data[sort_indices]

result = np.zeros(data_shape, dtype=valid_data.dtype)
np.maximum.at(result, sorted_segment_ids, sorted_data)

return result
def segment_max(data, segment_ids, num_segments=None, sorted=False):
return _segment_reduction_fn(
data, segment_ids, np.maximum, num_segments, sorted
)


def top_k(x, k, sorted=False):
Expand Down
12 changes: 12 additions & 0 deletions keras/src/backend/tensorflow/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@

def segment_sum(data, segment_ids, num_segments=None, sorted=False):
if sorted:
if num_segments is not None:
raise ValueError(
"Argument `num_segments` cannot be set when sorted is True "
"when using the tensorflow backend."
f"Received: num_segments={num_segments}, sorted={sorted}."
)
return tf.math.segment_sum(data, segment_ids)
else:
if num_segments is None:
Expand All @@ -19,6 +25,12 @@ def segment_sum(data, segment_ids, num_segments=None, sorted=False):

def segment_max(data, segment_ids, num_segments=None, sorted=False):
if sorted:
if num_segments is not None:
raise ValueError(
"Argument `num_segments` cannot be set when sorted is True "
"when using the tensorflow backend."
f"Received: num_segments={num_segments}, sorted={sorted}."
)
return tf.math.segment_max(data, segment_ids)
else:
if num_segments is None:
Expand Down
51 changes: 14 additions & 37 deletions keras/src/backend/torch/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
from keras.src.backend.torch.numpy import pad


def segment_sum(data, segment_ids, num_segments=None, **kwargs):
data = convert_to_tensor(data)
segment_ids = convert_to_tensor(segment_ids)
def _segment_reduction_fn(data, segment_ids, reduction_method, num_segments):
num_repeats = torch.prod(
torch.tensor(data.shape[1:], device=get_device())
).long()
Expand All @@ -39,8 +37,13 @@ def segment_sum(data, segment_ids, num_segments=None, **kwargs):
# Add one more dimension to the result shape with the "+1".
shape = (num_segments + 1,) + tuple(data.shape[1:])

result = torch.zeros(*shape, device=get_device()).scatter_add(
0, segment_ids, data.float()
if reduction_method == "amax":
result = torch.ones(*shape, device=get_device()) * -float("Inf")
else:
result = torch.zeros(*shape, device=get_device())

result = result.scatter_reduce(
0, segment_ids, data.float(), reduction_method
)

# Removing the extra dimension.
Expand All @@ -49,42 +52,16 @@ def segment_sum(data, segment_ids, num_segments=None, **kwargs):
return result.type(data.dtype)


def segment_max(data, segment_ids, num_segments=None, **kwargs):
def segment_sum(data, segment_ids, num_segments=None, **kwargs):
data = convert_to_tensor(data)
segment_ids = convert_to_tensor(segment_ids)
num_repeats = torch.prod(
torch.tensor(data.shape[1:], device=get_device())
).long()
# To use `scatter_reduce` in torch, we need to replicate `segment_ids` into
# the shape of `data`.
segment_ids = (
segment_ids.repeat_interleave(num_repeats)
.view(*data.shape)
.type(torch.int64)
)
num_segments = num_segments or len(torch.unique(segment_ids))

# .scatter_reduce does not support -1 in the indices.
# Add all out-of-bound indices value to an extra dimension after
# num_segments, which is removed before returning the result.

# Replacing the out-of-bound indices.
segment_ids = torch.where(segment_ids >= 0, segment_ids, num_segments)
segment_ids = torch.where(
segment_ids < num_segments, segment_ids, num_segments
)

# Add one more dimension to the result shape with the "+1".
shape = (num_segments + 1,) + tuple(data.shape[1:])

result = torch.zeros(*shape, device=get_device()).scatter_reduce(
0, segment_ids, data.float(), "amax"
)
return _segment_reduction_fn(data, segment_ids, "sum", num_segments)

# Removing the extra dimension.
result = result[:-1, ...]

return result.type(data.dtype)
def segment_max(data, segment_ids, num_segments=None, **kwargs):
data = convert_to_tensor(data)
segment_ids = convert_to_tensor(segment_ids)
return _segment_reduction_fn(data, segment_ids, "amax", num_segments)


def top_k(x, k, sorted=True):
Expand Down
57 changes: 39 additions & 18 deletions keras/src/ops/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,44 @@
from keras.src.ops.operation_utils import reduce_shape


class SegmentSum(Operation):
def _segment_reduce_validation(data, segment_ids):
    """Check that `data` and `segment_ids` have compatible shapes.

    Only the `.shape` attribute of each argument is read. A leading
    dimension of `None` is treated as unknown and is not compared.

    Args:
        data: Object exposing `.shape`; must have at least one dimension.
        segment_ids: Object exposing `.shape`; must be exactly 1-D, and its
            length must equal `data.shape[0]` when both are known.

    Raises:
        ValueError: If `segment_ids` is not 1-D, if `data` has no
            dimensions, or if both leading dimensions are known and differ.
    """
    data_shape = data.shape
    segment_ids_shape = segment_ids.shape
    # `!= 1` rather than `> 1`: a 0-D `segment_ids` has to be rejected here,
    # otherwise indexing its shape below fails with an opaque IndexError.
    if len(segment_ids_shape) != 1:
        raise ValueError(
            "Argument `segment_ids` should be an 1-D vector, got shape: "
            f"{segment_ids_shape}. Consider either flatten input with "
            "segment_ids.reshape((-1)) and "
            "data.reshape((-1, ) + data.shape[len(segment_ids.shape):]) or "
            "vectorize with vmap."
        )
    # A scalar `data` has no leading dimension to reduce over.
    if len(data_shape) < 1:
        raise ValueError(
            "Argument `data` should have at least one dimension, got shape: "
            f"{data_shape}."
        )
    leading_ids_dim = segment_ids_shape[0]
    leading_data_dim = data_shape[0]
    if (
        leading_ids_dim is not None
        and leading_data_dim is not None
        and leading_ids_dim != leading_data_dim
    ):
        raise ValueError(
            "Argument `segment_ids` and `data` should have same leading "
            f"dimension. Got {segment_ids_shape} v.s. "
            f"{data_shape}."
        )


class SegmentReduction(Operation):
def __init__(self, num_segments=None, sorted=False):
super().__init__()
self.num_segments = num_segments
self.sorted = sorted

def compute_output_spec(self, data, segment_ids):
num_segments = self.num_segments
output_shape = (num_segments,) + tuple(data.shape[1:])
def compute_output_spec(self, data, _):
output_shape = (self.num_segments,) + tuple(data.shape[1:])
return KerasTensor(shape=output_shape, dtype=data.dtype)


class SegmentSum(SegmentReduction):

def call(self, data, segment_ids):
_segment_reduce_validation(data, segment_ids)
return backend.math.segment_sum(
data,
segment_ids,
Expand All @@ -34,8 +60,9 @@ def segment_sum(data, segment_ids, num_segments=None, sorted=False):
Args:
data: Input tensor.
segment_ids: A 1-D tensor containing segment indices for each
element in `data`.
segment_ids: A 1-D tensor containing segment indices for each
element in `data`. Its length should match the leading
dimension of `data` (`data.shape[0]`).
num_segments: An integer representing the total number of
segments. If not specified, it is inferred from the maximum
value in `segment_ids`.
Expand All @@ -54,25 +81,18 @@ def segment_sum(data, segment_ids, num_segments=None, sorted=False):
>>> keras.ops.segment_sum(data, segment_ids,num_segments)
array([3, 30, 300], dtype=int32)
"""
_segment_reduce_validation(data, segment_ids)
if any_symbolic_tensors((data,)):
return SegmentSum(num_segments, sorted).symbolic_call(data, segment_ids)
return backend.math.segment_sum(
data, segment_ids, num_segments=num_segments, sorted=sorted
)


class SegmentMax(Operation):
def __init__(self, num_segments=None, sorted=False):
super().__init__()
self.num_segments = num_segments
self.sorted = sorted

def compute_output_spec(self, data, segment_ids):
num_segments = self.num_segments
output_shape = (num_segments,) + tuple(data.shape[1:])
return KerasTensor(shape=output_shape, dtype=data.dtype)
class SegmentMax(SegmentReduction):

def call(self, data, segment_ids):
_segment_reduce_validation(data, segment_ids)
return backend.math.segment_max(
data,
segment_ids,
Expand All @@ -87,8 +107,8 @@ def segment_max(data, segment_ids, num_segments=None, sorted=False):
Args:
data: Input tensor.
segment_ids: A 1-D tensor containing segment indices for each
element in `data`.
segment_ids: A 1-D tensor containing segment indices for each
element in `data`. Its length should match the leading
dimension of `data` (`data.shape[0]`).
num_segments: An integer representing the total number of
segments. If not specified, it is inferred from the maximum
value in `segment_ids`.
Expand All @@ -107,6 +127,7 @@ def segment_max(data, segment_ids, num_segments=None, sorted=False):
>>> keras.ops.segment_max(data, segment_ids, num_segments)
array([2, 20, 200], dtype=int32)
"""
_segment_reduce_validation(data, segment_ids)
if any_symbolic_tensors((data,)):
return SegmentMax(num_segments, sorted).symbolic_call(data, segment_ids)
return backend.math.segment_max(
Expand Down
Loading

0 comments on commit 5c8363b

Please sign in to comment.