The issue was that the iterator was not fully consumed, so `on_epoch_end` was never called.

Added an exception to detect this situation in the future.

Added a unit test that exercises `model.fit()` with all combinations of data adapters.
hertschuh authored Nov 5, 2024 · 1 parent 1a01cbd · commit b91cfe5
Showing 4 changed files with 206 additions and 14 deletions.
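
Before the per-file diffs, a minimal sketch (not part of the commit) of the failure mode described above and the kind of check the new unit test performs. It assumes the public `keras.utils.PyDataset` API; `CountingPyDataset` and the `steps_per_epoch` trigger are illustrative assumptions, not the commit's exact reproduction:

```python
import numpy as np
import keras


class CountingPyDataset(keras.utils.PyDataset):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.epoch_ends = 0

    def __len__(self):
        return 20  # 20 batches per epoch

    def __getitem__(self, idx):
        return np.ones((5, 4), "float32"), np.zeros((5, 3), "float32")

    def on_epoch_end(self):
        self.epoch_ends += 1


model = keras.Sequential([keras.layers.Dense(3)])
model.compile(optimizer="sgd", loss="mse")

ds = CountingPyDataset()
# Ending each epoch early leaves the underlying iterator partially
# consumed, which is the state described in the commit message.
model.fit(ds, epochs=3, steps_per_epoch=10, verbose=0)
assert ds.epoch_ends == 3  # expected to hold once the fix is in place
```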
keras/src/trainers/data_adapters/py_dataset_adapter.py (10 additions, 6 deletions)

```diff
@@ -193,6 +193,7 @@ def __init__(
         self.enqueuer = None
         self.shuffle = shuffle
         self._output_signature = None
+        self._within_epoch = False

         workers = self.py_dataset.workers
         use_multiprocessing = self.py_dataset.use_multiprocessing
@@ -314,6 +315,12 @@ def get_torch_dataloader(self):
         return data_adapter_utils.get_torch_dataloader(self._get_iterator())

     def on_epoch_begin(self):
+        if self._within_epoch:
+            raise ValueError(
+                "`on_epoch_begin` was called twice without `on_epoch_end` "
+                "having been called."
+            )
+        self._within_epoch = True
         if self.enqueuer:
             self.enqueuer.start()
         self.py_dataset.on_epoch_begin()
@@ -322,6 +329,7 @@ def on_epoch_end(self):
         if self.enqueuer:
             self.enqueuer.stop()
         self.py_dataset.on_epoch_end()
+        self._within_epoch = False

     @property
     def num_batches(self):
@@ -460,7 +468,7 @@ def start(self):
             return
         self.running = True
         self.run_thread = threading.Thread(target=self._run)
-        self.run_thread.name = f"Worker_{self.uid}"  # TODO remove
+        self.run_thread.name = f"Worker_{self.uid}"
         self.run_thread.daemon = True
         self.run_thread.start()
@@ -644,11 +652,7 @@ def get(self):
                 if inputs is not None:
                     yield inputs
         except queue.Empty:
-            warnings.warn(
-                "Generator ran out of batches before reaching `num_batches`"
-            )
-            self.stop()
-            return
+            pass
         except Exception as e:
             self.stop(drain_queue_and_join=True)
             raise e
```
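
For clarity, here is a standalone sketch of the pairing invariant that the new `_within_epoch` flag enforces. The `EpochHooks` class is hypothetical, not Keras code; it just mirrors the guard added above:

```python
# Minimal model of the new guard: `on_epoch_begin` must alternate with
# `on_epoch_end`; a second begin without an end raises immediately.
class EpochHooks:
    def __init__(self):
        self._within_epoch = False

    def on_epoch_begin(self):
        if self._within_epoch:
            raise ValueError(
                "`on_epoch_begin` was called twice without `on_epoch_end` "
                "having been called."
            )
        self._within_epoch = True

    def on_epoch_end(self):
        self._within_epoch = False


hooks = EpochHooks()
hooks.on_epoch_begin()  # epoch 1 starts
hooks.on_epoch_end()    # epoch 1 ends, flag cleared
hooks.on_epoch_begin()  # epoch 2 starts cleanly
try:
    hooks.on_epoch_begin()  # unbalanced: epoch 2 never ended
except ValueError as err:
    print(err)
```

The guard turns a silently skipped `on_epoch_end` into a loud, diagnosable error rather than a subtle behavior change.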
keras/src/trainers/data_adapters/py_dataset_adapter_test.py (13 additions, 8 deletions)

```diff
@@ -143,18 +143,23 @@ def test_basic_flow(
     ):
         if use_multiprocessing and shuffle:
             pytest.skip("Starting processes is slow, test fewer variants")
-        if testing.tensorflow_uses_gpu():
-            pytest.skip("This test is flaky with TF on GPU")

         set_random_seed(1337)
         x = np.random.random((64, 4)).astype("float32")
         y = np.array([[i, i] for i in range(64)], dtype="float32")
-        if dataset_type == "tf":
-            x, y = tf.constant(x), tf.constant(y)
-        elif dataset_type == "jax":
-            x, y = jax.numpy.array(x), jax.numpy.array(y)
-        elif dataset_type == "torch":
-            x, y = torch.as_tensor(x), torch.as_tensor(y)
+        CPU_DEVICES = {
+            "tensorflow": "CPU:0",
+            "jax": "cpu:0",
+            "torch": "cpu",
+            "numpy": "cpu",
+        }
+        with backend.device(CPU_DEVICES[backend.backend()]):
+            if dataset_type == "tf":
+                x, y = tf.constant(x), tf.constant(y)
+            elif dataset_type == "jax":
+                x, y = jax.numpy.array(x), jax.numpy.array(y)
+            elif dataset_type == "torch":
+                x, y = torch.as_tensor(x), torch.as_tensor(y)
         py_dataset = ExamplePyDataset(
             x,
             y,
```
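
The test now pins tensor creation to the CPU on every backend, which is also why the TF-on-GPU flakiness skip above could be removed. A small sketch of the same pattern against the public API; it assumes `keras.device` and `keras.backend.backend()` behave as in Keras 3:

```python
# Create test tensors on the CPU regardless of the active backend, so GPU
# placement cannot introduce nondeterminism into numeric comparisons.
import keras

CPU_DEVICES = {
    "tensorflow": "CPU:0",
    "jax": "cpu:0",
    "torch": "cpu",
    "numpy": "cpu",
}

with keras.device(CPU_DEVICES[keras.backend.backend()]):
    x = keras.ops.ones((5, 4))   # allocated on the CPU
    y = keras.ops.zeros((5, 3))
```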
keras/src/trainers/epoch_iterator.py (1 addition, 0 deletions)

```diff
@@ -91,6 +91,7 @@ def reset(self):
         self._num_batches = self.data_adapter.num_batches
         self._steps_seen = 0
         self._epoch_iterator = None
+        self.data_adapter.on_epoch_end()

     def _enumerate_iterator(self):
         self.data_adapter.on_epoch_begin()
```
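
This one-line change addresses the root cause: when `fit()` ends an epoch early (for example via `steps_per_epoch`), the epoch iterator is discarded without ever raising `StopIteration`, so the data adapter never learned that the epoch finished. Calling `on_epoch_end()` from `reset()` closes that gap. A plain-Python illustration of the abandoned-iterator state (hypothetical, not Keras code):

```python
def batches():
    for i in range(20):
        yield i


it = iter(batches())
for _ in range(10):  # analogous to steps_per_epoch=10 out of 20 batches
    next(it)
# `it` is abandoned here; StopIteration is never raised. This is the
# "iterator was not fully consumed" state from the commit message, and
# `reset()` is the point where the epoch can still be finalized.
```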
keras/src/trainers/trainer_test.py (182 additions, 0 deletions)

```diff
@@ -18,6 +18,7 @@
 from keras.src.callbacks.callback import Callback
 from keras.src.optimizers.rmsprop import RMSprop
 from keras.src.testing.test_utils import named_product
+from keras.src.trainers.data_adapters import py_dataset_adapter

 if backend.backend() == "jax":
     from keras.src.backend.jax.trainer import JAXTrainer as Trainer
@@ -141,6 +142,82 @@ def call(self, x, training=False):
         return x * 0


+class TestPyDataset(py_dataset_adapter.PyDataset):
+    def __init__(self, infinite=False, **kwargs):
+        super().__init__(**kwargs)
+        self.infinite = infinite
+
+    @property
+    def num_batches(self):
+        return None if self.infinite else 20
+
+    def __getitem__(self, idx):
+        CPU_DEVICES = {
+            "tensorflow": "CPU:0",
+            "jax": "cpu:0",
+            "torch": "cpu",
+        }
+        with backend.device(CPU_DEVICES[backend.backend()]):
+            return ops.ones((5, 4)), ops.zeros((5, 3))
+
+
+def create_dataset(dataset_type, dataset_kwargs):
+    if dataset_type == "np_array":
+        return np.ones((100, 4)), np.zeros((100, 3))
+    elif dataset_type == "native_array":
+        return ops.ones((100, 4)), ops.zeros((100, 3))
+    elif dataset_type == "py_dataset":
+        return TestPyDataset(**dataset_kwargs), None
+    elif dataset_type == "tf_dataset":
+        import tensorflow as tf
+
+        dataset = tf.data.Dataset.from_tensor_slices(
+            (tf.ones((100, 4)), tf.zeros((100, 3)))
+        ).batch(5)
+        if dataset_kwargs.get("infinite", False):
+            dataset = dataset.repeat()
+        return dataset, None
+    elif dataset_type == "torch_dataloader":
+        import torch
+
+        class TestIterableDataset(torch.utils.data.IterableDataset):
+            def __iter__(self):
+                for _ in range(20):
+                    yield torch.ones((5, 4)), torch.zeros((5, 3))
+
+        class TestIterableDatasetWithLen(TestIterableDataset):
+            def __len__(self):
+                return 20
+
+        if dataset_kwargs.get("iterable", False):
+            if dataset_kwargs.get("has_len", False):
+                dataset = TestIterableDatasetWithLen()
+            else:
+                dataset = TestIterableDataset()
+            return torch.utils.data.DataLoader(dataset), None
+        else:
+            dataset = torch.utils.data.TensorDataset(
+                torch.ones((100, 4)), torch.zeros((100, 3))
+            )
+            return torch.utils.data.DataLoader(dataset, batch_size=5), None
+    elif dataset_type == "generator":
+
+        def generate_finite():
+            for _ in range(20):
+                yield ops.ones((5, 4)), ops.zeros((5, 3))
+
+        def generate_infinite():
+            while True:
+                yield ops.ones((5, 4)), ops.zeros((5, 3))
+
+        if dataset_kwargs.get("infinite", False):
+            return generate_infinite(), None
+        else:
+            return generate_finite(), None
+    else:
+        raise ValueError(f"Invalid dataset type {dataset_type}")
+
+
 def sparse_generator(generator_type):
     if generator_type == "scipy":
         import scipy
@@ -397,6 +474,111 @@ def test_fit_flow(self, run_eagerly, jit_compile, use_steps_per_epoch):
             atol=1.0,  # TODO: results vary across backends
         )

+    @parameterized.named_parameters(
+        [
+            {
+                "testcase_name": "np_array",
+                "dataset_type": "np_array",
+                "fit_kwargs": {"batch_size": 5},
+            },
+            {
+                "testcase_name": "native_array",
+                "dataset_type": "native_array",
+                "fit_kwargs": {"batch_size": 5},
+            },
+            {
+                "testcase_name": "py_dataset",
+                "dataset_type": "py_dataset",
+            },
+            {
+                "testcase_name": "py_dataset_infinite",
+                "dataset_type": "py_dataset",
+                "dataset_kwargs": {"infinite": True},
+                "fit_kwargs": {"steps_per_epoch": 20},
+            },
+            {
+                "testcase_name": "py_dataset_multithreading",
+                "dataset_type": "py_dataset",
+                "dataset_kwargs": {"workers": 2},
+            },
+            {
+                "testcase_name": "py_dataset_multithreading_infinite",
+                "dataset_type": "py_dataset",
+                "dataset_kwargs": {"infinite": True, "workers": 2},
+                "fit_kwargs": {"steps_per_epoch": 20},
+            },
+            {
+                "testcase_name": "py_dataset_multiprocessing",
+                "dataset_type": "py_dataset",
+                "dataset_kwargs": {"workers": 2, "use_multiprocessing": True},
+            },
+            {
+                "testcase_name": "py_dataset_multiprocessing_infinite",
+                "dataset_type": "py_dataset",
+                "dataset_kwargs": {
+                    "infinite": True,
+                    "workers": 2,
+                    "use_multiprocessing": True,
+                },
+                "fit_kwargs": {"steps_per_epoch": 20},
+            },
+            {
+                "testcase_name": "tf_dataset",
+                "dataset_type": "tf_dataset",
+            },
+            {
+                "testcase_name": "tf_dataset_infinite",
+                "dataset_type": "tf_dataset",
+                "dataset_kwargs": {"infinite": True},
+                "fit_kwargs": {"steps_per_epoch": 20},
+            },
+            {
+                "testcase_name": "torch_dataloader_tensor",
+                "dataset_type": "torch_dataloader",
+            },
+            {
+                "testcase_name": "torch_dataloader_iterable",
+                "dataset_type": "torch_dataloader",
+                "dataset_kwargs": {"iterable": True, "has_len": False},
+            },
+            {
+                "testcase_name": "torch_dataloader_iterable_with_len",
+                "dataset_type": "torch_dataloader",
+                "dataset_kwargs": {"iterable": True, "has_len": True},
+            },
+            {
+                "testcase_name": "generator",
+                "dataset_type": "generator",
+            },
+            {
+                "testcase_name": "generator_infinite",
+                "dataset_type": "generator",
+                "dataset_kwargs": {"infinite": True},
+                "fit_kwargs": {"steps_per_epoch": 20},
+            },
+        ]
+    )
+    @pytest.mark.requires_trainable_backend
+    def test_fit_with_data_adapter(
+        self, dataset_type, dataset_kwargs={}, fit_kwargs={}
+    ):
+        if (
+            dataset_kwargs.get("use_multiprocessing", False)
+            and backend.backend() == "jax"
+        ):
+            pytest.skip("Multiprocessing not supported with JAX backend")
+
+        model = ExampleModel(units=3)
+        optimizer = optimizers.Adagrad()
+        model.compile(
+            optimizer=optimizer,
+            loss=losses.MeanSquaredError(),
+            metrics=[metrics.MeanSquaredError()],
+            jit_compile=True,
+        )
+        x, y = create_dataset(dataset_type, dataset_kwargs)
+        model.fit(x, y, epochs=3, **fit_kwargs)
+
     @parameterized.named_parameters(
         [
             ("eager", True, False, False),
```
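
The new test can be run in isolation with pytest's `-k` filter, for example `pytest keras/src/trainers/trainer_test.py -k test_fit_with_data_adapter` (a plausible invocation; exact flags depend on the local setup).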
