From 73e348f99ca9ad6ba3b32a8a4a2544b1a468e75b Mon Sep 17 00:00:00 2001 From: Naji El Hachem Date: Mon, 2 Oct 2023 18:28:27 +0200 Subject: [PATCH] fix circular data source --- .../src/fairseq2n/data/circular_data_source.cc | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/fairseq2n/src/fairseq2n/data/circular_data_source.cc b/fairseq2n/src/fairseq2n/data/circular_data_source.cc index 0fd12cb47..f5b8e7a44 100644 --- a/fairseq2n/src/fairseq2n/data/circular_data_source.cc +++ b/fairseq2n/src/fairseq2n/data/circular_data_source.cc @@ -18,16 +18,17 @@ circular_data_source::circular_data_source(std::vector &&pipeline std::optional circular_data_source::next() { - if (eod()) - return {}; + // One or more data pipelines might be empty, so we have to keep looping + std::optional output{}; + while (!output && !eod()) { + auto pipeline_idx = next_index_gen_(); - auto pipeline_idx = next_index_gen_(); + if (!buffer_[pipeline_idx]) // init buffer at index + buffer_[pipeline_idx] = next_in_pipeline(pipeline_idx); - if (!buffer_[pipeline_idx]) // init buffer at index + output = buffer_[pipeline_idx]; buffer_[pipeline_idx] = next_in_pipeline(pipeline_idx); - - auto output = buffer_[pipeline_idx]; - buffer_[pipeline_idx] = next_in_pipeline(pipeline_idx); + } return output; } @@ -35,7 +36,7 @@ circular_data_source::next() void circular_data_source::reset() { - buffer_.clear(); + buffer_.assign(pipelines_.size(), std::nullopt); is_epoch_done_.assign(pipelines_.size(), false);