From a24baa050d432be2a6dce7b88118266fe4db5553 Mon Sep 17 00:00:00 2001 From: Naji El Hachem Date: Mon, 2 Oct 2023 18:36:01 +0200 Subject: [PATCH] Add stop_at_shortest flag --- fairseq2n/src/fairseq2n/data/data_pipeline.cc | 7 ++++--- fairseq2n/src/fairseq2n/data/data_pipeline.h | 3 ++- fairseq2n/src/fairseq2n/data/sample_data_source.cc | 2 +- fairseq2n/src/fairseq2n/data/sample_data_source.h | 2 +- src/fairseq2/data/data_pipeline.py | 3 +++ 5 files changed, 11 insertions(+), 6 deletions(-) diff --git a/fairseq2n/src/fairseq2n/data/data_pipeline.cc b/fairseq2n/src/fairseq2n/data/data_pipeline.cc index dbb350215..484b65b79 100644 --- a/fairseq2n/src/fairseq2n/data/data_pipeline.cc +++ b/fairseq2n/src/fairseq2n/data/data_pipeline.cc @@ -228,7 +228,8 @@ data_pipeline::round_robin(std::vector pipelines) data_pipeline_builder data_pipeline::sample( std::vector pipelines, - std::optional> weights) + std::optional> weights, + bool stop_at_shortest) { if (pipelines.empty()) throw_( @@ -251,8 +252,8 @@ data_pipeline::sample( auto tmp = std::make_shared>(std::move(pipelines)); - auto factory = [tmp, weights=std::move(weights.value())]() mutable { - return std::make_unique(std::move(*tmp), std::move(weights)); + auto factory = [tmp, weights=std::move(weights.value()), stop_at_shortest]() mutable { + return std::make_unique(std::move(*tmp), std::move(weights), stop_at_shortest); }; return data_pipeline_builder{std::move(factory)}; diff --git a/fairseq2n/src/fairseq2n/data/data_pipeline.h b/fairseq2n/src/fairseq2n/data/data_pipeline.h index 46469e61a..525a50a73 100644 --- a/fairseq2n/src/fairseq2n/data/data_pipeline.h +++ b/fairseq2n/src/fairseq2n/data/data_pipeline.h @@ -86,7 +86,8 @@ class FAIRSEQ2_API data_pipeline { static data_pipeline_builder sample( std::vector pipelines, - std::optional> weights = {}); + std::optional> weights = {}, + bool stop_at_shortest = true); static data_pipeline_builder constant(data example, std::optional key = {}); diff --git 
a/fairseq2n/src/fairseq2n/data/sample_data_source.cc b/fairseq2n/src/fairseq2n/data/sample_data_source.cc index 17c1f9511..a8e8c29d2 100644 --- a/fairseq2n/src/fairseq2n/data/sample_data_source.cc +++ b/fairseq2n/src/fairseq2n/data/sample_data_source.cc @@ -16,7 +16,7 @@ namespace fairseq2n::detail { -sample_data_source::sample_data_source(std::vector &&pipelines, std::vector &&weights) +sample_data_source::sample_data_source(std::vector &&pipelines, std::vector &&weights, bool stop_at_shortest) : pipelines_(std::move(pipelines)) { weights_ = make_tensor_from_vector(weights, { static_cast(pipelines_.size()) }); diff --git a/fairseq2n/src/fairseq2n/data/sample_data_source.h b/fairseq2n/src/fairseq2n/data/sample_data_source.h index a31bb34b1..408037592 100644 --- a/fairseq2n/src/fairseq2n/data/sample_data_source.h +++ b/fairseq2n/src/fairseq2n/data/sample_data_source.h @@ -21,7 +21,7 @@ namespace fairseq2n::detail { class sample_data_source final : public data_source { public: explicit - sample_data_source(std::vector &&pipelines, std::vector &&weights); + sample_data_source(std::vector &&pipelines, std::vector &&weights, bool stop_at_shortest); std::optional next() override; diff --git a/src/fairseq2/data/data_pipeline.py b/src/fairseq2/data/data_pipeline.py index 42edf963f..2cea5b62c 100644 --- a/src/fairseq2/data/data_pipeline.py +++ b/src/fairseq2/data/data_pipeline.py @@ -105,6 +105,7 @@ def round_robin(pipelines: Sequence["DataPipeline"]) -> "DataPipelineBuilder": def sample( pipelines: Sequence["DataPipeline"], weights: Optional[Sequence[float]] = None, + stop_at_shortest: bool = True, ) -> "DataPipelineBuilder": """Extract examples from ``pipelines`` by sampling based on ``weights``. @@ -112,6 +113,8 @@ def sample( The data pipelines to sample from. :param weights: Desired distribution of pipelines. If None, use uniform distribution. + :param stop_at_shortest: + Flag to stop sampling when first pipeline reaches its end. """ @staticmethod