Skip to content

Commit

Permalink
Add validation_split support for backend native tensors.
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Jan 29, 2024
1 parent 71a601b commit 088f0fc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
4 changes: 2 additions & 2 deletions keras/trainers/data_adapters/data_adapter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
if backend.backend() == "tensorflow":
from keras.utils.module_utils import tensorflow as tf

ARRAY_TYPES = ARRAY_TYPES + (tf.Tensor, tf.RaggedTensor)
ARRAY_TYPES = ARRAY_TYPES + (tf.RaggedTensor,)
if pandas:
ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame)

Expand Down Expand Up @@ -158,7 +158,7 @@ def train_validation_split(arrays, validation_split):
"""

def _can_split(t):
return isinstance(t, ARRAY_TYPES) or t is None
return backend.is_tensor(t) or isinstance(t, ARRAY_TYPES) or t is None

flat_arrays = tree.flatten(arrays)
unsplitable = [type(t) for t in flat_arrays if not _can_split(t)]
Expand Down
15 changes: 15 additions & 0 deletions keras/trainers/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,21 @@ def test_fit_with_val_split(
self.assertIn("loss", history)
self.assertIn("val_loss", history)

# Test with backend-native tensors.
x = ops.ones((dataset_size, 4))
y = ops.zeros((dataset_size, 3))
history = model.fit(
x,
y,
batch_size=batch_size,
steps_per_epoch=steps_per_epoch if use_steps_per_epoch else None,
epochs=epochs,
validation_split=0.2,
)
history = history.history
self.assertIn("loss", history)
self.assertIn("val_loss", history)

@parameterized.named_parameters(
[
("eager_tf_sparse", True, False),
Expand Down

0 comments on commit 088f0fc

Please sign in to comment.