diff --git a/keras/src/layers/activations/softmax.py b/keras/src/layers/activations/softmax.py index c1fee581a89..3fbcbe5538d 100644 --- a/keras/src/layers/activations/softmax.py +++ b/keras/src/layers/activations/softmax.py @@ -50,6 +50,12 @@ def __init__(self, axis=-1, **kwargs): def call(self, inputs, mask=None): if mask is not None: + if mask.shape != inputs.shape: + raise ValueError( + "`mask` and `inputs` must have same shape. " + f"Got inputs shape: {inputs.shape} and " + f"mask shape: {mask.shape}" + ) adder = ( 1.0 - backend.cast(mask, inputs.dtype) ) * _large_negative_number(inputs.dtype)