Skip to content

Commit

Permalink
Allow op.select to accept tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed May 24, 2024
1 parent 479a0e0 commit b476304
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 3 deletions.
2 changes: 1 addition & 1 deletion keras/src/legacy/preprocessing/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def hash_function(w):


@keras_export("keras._legacy.preprocessing.text.Tokenizer")
class Tokenizer(object):
class Tokenizer:
"""DEPRECATED."""

def __init__(
Expand Down
6 changes: 5 additions & 1 deletion keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6125,12 +6125,16 @@ def select(condlist, choicelist, default=0):
# Returns: tensor([0, 1, 2, 42, 16, 25])
```
"""
if not isinstance(condlist, list) or not isinstance(choicelist, list):
if not isinstance(condlist, (list, tuple)) or not isinstance(
choicelist, (list, tuple)
):
raise ValueError(
"condlist and choicelist must be lists. Received: "
f"type(condlist) = {type(condlist)}, "
f"type(choicelist) = {type(choicelist)}"
)
condlist = list(condlist)
choicelist = list(choicelist)
if not condlist or not choicelist:
raise ValueError(
"condlist and choicelist must not be empty. Received: "
Expand Down
7 changes: 7 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4265,6 +4265,13 @@ def test_select(self):
y = knp.select(condlist, choicelist, 42)
self.assertAllClose(y, [0, 1, 2, 42, 16, 25])

# Test with tuples
condlist = (x < 3, x > 3)
choicelist = (x, x**2)
y = knp.select(condlist, choicelist, 42)
self.assertAllClose(y, [0, 1, 2, 42, 16, 25])

# Test with symbolic tensors
x = backend.KerasTensor((6,))
condlist = [x < 3, x > 3]
choicelist = [x, x**2]
Expand Down
2 changes: 1 addition & 1 deletion keras/src/trainers/data_adapters/data_adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
class DataAdapter(object):
class DataAdapter:
"""Base class for input data adapters.
The purpose of a DataAdapter is to provide a unfied interface to
Expand Down

0 comments on commit b476304

Please sign in to comment.