Skip to content

Commit

Permalink
Fix argmax/argmin keepdims with defined axis in TF
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Apr 28, 2024
1 parent 0f3bd52 commit b41f687
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
3 changes: 2 additions & 1 deletion keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from keras.src.backend.tensorflow import sparse
from keras.src.backend.tensorflow.core import cast
from keras.src.backend.tensorflow.core import convert_to_tensor
from keras.src.backend.tensorflow.core import shape as shape_op


@sparse.elementwise_binary_union(tf.sparse.add)
Expand Down Expand Up @@ -756,7 +757,7 @@ def _keepdims(x, y, axis):
if axis is None:
shape = [1 for _ in range(len(x.shape))]
else:
shape = [tf.shape[i] for i in range(len(x.shape))]
shape = list(shape_op(x))
for axis in tree.flatten(axis):
shape[axis] = 1
y = tf.reshape(y, shape)
Expand Down
6 changes: 5 additions & 1 deletion keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3105,9 +3105,13 @@ def test_arctanh(self):
self.assertAllClose(knp.Arctanh()(x), np.arctanh(x))

def test_argmax(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
x = np.array([[1, 2, 3], [3, 2, 1], [4, 5, 6]])
self.assertAllClose(knp.argmax(x), np.argmax(x))
self.assertAllClose(knp.argmax(x, axis=1), np.argmax(x, axis=1))
self.assertAllClose(
knp.argmax(x, axis=1, keepdims=True),
np.argmax(x, axis=1, keepdims=True),
)
self.assertAllClose(
knp.argmax(x, keepdims=True), np.argmax(x, keepdims=True)
)
Expand Down

0 comments on commit b41f687

Please sign in to comment.