Skip to content

Commit

Permalink
StringLookup layer can now take tf.SparseTensors as input. (keras-tea…
Browse files Browse the repository at this point in the history
  • Loading branch information
hertschuh authored May 27, 2024
1 parent 510d406 commit a243d91
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion keras/src/layers/preprocessing/string_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def get_config(self):
return {**base_config, **config}

def call(self, inputs):
if isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
if isinstance(inputs, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)):
tf_inputs = True
else:
tf_inputs = False
Expand Down
21 changes: 21 additions & 0 deletions keras/src/layers/preprocessing/string_lookup_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest
from tensorflow import data as tf_data

from keras.src import backend
Expand Down Expand Up @@ -38,6 +39,26 @@ def test_fixed_vocabulary(self):
self.assertTrue(backend.is_tensor(output))
self.assertAllClose(output, np.array([2, 3, 0]))

@pytest.mark.skipif(
not backend.backend() == "tensorflow", reason="Requires tf.SparseTensor"
)
def test_sparse_inputs(self):
import tensorflow as tf

layer = layers.StringLookup(
output_mode="int",
vocabulary=["a", "b", "c"],
)
input_data = tf.SparseTensor(
indices=[[0, 0], [1, 1], [2, 2]],
values=["b", "c", "d"],
dense_shape=(3, 3),
)
output = layer(input_data)
self.assertIsInstance(output, tf.SparseTensor)
self.assertAllClose(output, np.array([[2, 0, 0], [0, 3, 0], [0, 0, 0]]))
self.assertAllClose(output.values, np.array([2, 3, 0]))

def test_set_vocabulary(self):
layer = layers.StringLookup(
output_mode="int",
Expand Down

0 comments on commit a243d91

Please sign in to comment.