From 52b05c79172302d1201b7c3d7439d991f29e9989 Mon Sep 17 00:00:00 2001 From: Naji El Hachem Date: Thu, 5 Oct 2023 16:18:54 +0200 Subject: [PATCH] Introduce stop_at_shortest in sample and round_robin (#76) * Add tensor.h to utils * Add sample_data_source class * Add sample method to data_pipeline * Add sample binding to data_pipeline * Add test_sample * Fix build errors * fix python lint errors * Add empty lines at end of c++ files * remove \n from runtime error regex match * Use float32 instead of float * Resolve nit comments * Remove set_seed from sampler * Fix document error * Handle generator state out of sample operator * Add circular data source * Use circular data_soruce in round_robin * fix circular data source * Add stop_on_shortest flag * Add stop_at_shortest flag to round_robin * Add stop_at_shortest round_robin test * Fix circular data source stop_at_shortest flag * Use circular data_source in sample op * Add up_sampling test * Rename circular_data_source to multi_data_source * Improve sample and round_robin doc * lint code * skip test that uses new API * nit changes * rename multi_data_source to composite_data_source * use std::exchange instead of = * resolve last nit comment --- .../fairseq2n/bindings/data/data_pipeline.cc | 17 ++- fairseq2n/src/fairseq2n/CMakeLists.txt | 1 + .../fairseq2n/data/composite_data_source.cc | 114 ++++++++++++++++++ .../fairseq2n/data/composite_data_source.h | 51 ++++++++ fairseq2n/src/fairseq2n/data/data_pipeline.cc | 15 ++- fairseq2n/src/fairseq2n/data/data_pipeline.h | 7 +- .../fairseq2n/data/round_robin_data_source.cc | 101 ++++------------ .../fairseq2n/data/round_robin_data_source.h | 15 +-- .../src/fairseq2n/data/sample_data_source.cc | 45 +++---- .../src/fairseq2n/data/sample_data_source.h | 6 +- src/fairseq2/data/data_pipeline.py | 14 ++- .../data/data_pipeline/test_round_robin.py | 23 ++++ tests/unit/data/data_pipeline/test_sample.py | 39 ++++-- 13 files changed, 302 insertions(+), 146 deletions(-) create mode 100644 fairseq2n/src/fairseq2n/data/composite_data_source.cc create mode 100644 fairseq2n/src/fairseq2n/data/composite_data_source.h diff --git a/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc b/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc index 4bbedb312..176d03929 100644 --- a/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc +++ b/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc @@ -304,7 +304,9 @@ def_data_pipeline(py::module_ &data_module) py::arg("disable_parallelism") = false) .def_static( "round_robin", - [](std::vector> &refs) + []( + std::vector> &refs, + bool stop_at_shortest) { std::vector pipelines{}; @@ -315,14 +317,16 @@ def_data_pipeline(py::module_ &data_module) return std::move(r.get()); }); - return data_pipeline::round_robin(std::move(pipelines)); + return data_pipeline::round_robin(std::move(pipelines), stop_at_shortest); }, - py::arg("pipelines")) + py::arg("pipelines"), + py::arg("stop_at_shortest") = false) .def_static( "sample", []( std::vector> &refs, - std::optional> weights) + std::optional> weights, + bool stop_at_shortest) { std::vector pipelines{}; @@ -334,10 +338,11 @@ def_data_pipeline(py::module_ &data_module) }); return data_pipeline::sample( - std::move(pipelines), std::move(weights)); + std::move(pipelines), std::move(weights), stop_at_shortest); }, py::arg("pipelines"), - py::arg("weights") = std::nullopt) + py::arg("weights") = std::nullopt, + py::arg("stop_at_shortest") = false) .def_static( "constant", [](data example, std::optional key) diff --git a/fairseq2n/src/fairseq2n/CMakeLists.txt b/fairseq2n/src/fairseq2n/CMakeLists.txt index b11c7cce2..b642d5d0a 100644 --- a/fairseq2n/src/fairseq2n/CMakeLists.txt +++ b/fairseq2n/src/fairseq2n/CMakeLists.txt @@ -34,6 +34,7 @@ target_sources(fairseq2n data/list_data_source.cc data/map_data_source.cc data/memory_stream.cc + data/composite_data_source.cc data/prefetch_data_source.cc data/py.cc data/record_reader.cc diff --git a/fairseq2n/src/fairseq2n/data/composite_data_source.cc b/fairseq2n/src/fairseq2n/data/composite_data_source.cc new file mode 100644 index 000000000..34cd1c0d3 --- /dev/null +++ b/fairseq2n/src/fairseq2n/data/composite_data_source.cc @@ -0,0 +1,114 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "fairseq2n/data/composite_data_source.h" + +namespace fairseq2n::detail { + +composite_data_source::composite_data_source( + std::vector &&pipelines, + index_generator_fn &&index_gen_fn, + bool stop_at_shortest) + : pipelines_(std::move(pipelines)), + next_index_gen_{std::move(index_gen_fn)}, + stop_at_shortest_{stop_at_shortest} +{ + if (!stop_at_shortest) { + is_epoch_done_ = std::vector(pipelines_.size(), false); + buffer_ = std::vector>(pipelines_.size(), std::nullopt); + } +} + +std::optional +composite_data_source::next() +{ + if (stop_at_shortest_) // with this flag on, the operator is a simple iterator + return pipelines_[next_index_gen_()].next(); + + // One or more data pipelines might be empty, so we have to keep looping + std::optional output{}; + while (!output && !eod()) { + auto pipeline_idx = next_index_gen_(); + auto &maybe_example = buffer_[pipeline_idx]; + + if (!maybe_example) // init buffer at first call + maybe_example = next_in_pipeline(pipeline_idx); + + output = std::exchange(maybe_example, next_in_pipeline(pipeline_idx)); + } + + return output; +} + +void +composite_data_source::reset() +{ + for (data_pipeline &pipeline : pipelines_) + pipeline.reset(); + + if (!stop_at_shortest_) { + buffer_.assign(pipelines_.size(), std::nullopt); + is_epoch_done_.assign(pipelines_.size(), false); + is_eod_ = false; + } +} + +void +composite_data_source::record_position(tape &t) const +{ + for (const data_pipeline &pipeline : pipelines_) + pipeline.record_position(t); + + if (!stop_at_shortest_) { + t.record(buffer_); + t.record(is_epoch_done_); + } +} + +void +composite_data_source::reload_position(tape &t) +{ + for (data_pipeline &pipeline : pipelines_) + pipeline.reload_position(t); + + if (!stop_at_shortest_) { + buffer_ = t.read>>(); + is_epoch_done_ = t.read>(); + is_eod_ = false; + } +} + +std::optional +composite_data_source::next_in_pipeline(std::size_t pipeline_idx) +{ + data_pipeline &pipeline = pipelines_[pipeline_idx]; + + std::optional maybe_example = pipeline.next(); + if (!maybe_example) { + is_epoch_done_[pipeline_idx] = true; + + pipeline.reset(); + + // Circle back to the first example. + maybe_example = pipeline.next(); + } + + return maybe_example; +} + +bool +composite_data_source::eod() +{ + is_eod_ = is_eod_ || std::all_of( + is_epoch_done_.begin(), is_epoch_done_.end(), [](bool b) + { + return b; + }); + + return is_eod_; +} + +} // namespace fairseq2n::detail diff --git a/fairseq2n/src/fairseq2n/data/composite_data_source.h b/fairseq2n/src/fairseq2n/data/composite_data_source.h new file mode 100644 index 000000000..bcb13e673 --- /dev/null +++ b/fairseq2n/src/fairseq2n/data/composite_data_source.h @@ -0,0 +1,51 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include "fairseq2n/data/data_pipeline.h" + + +using index_generator_fn = std::function; + +namespace fairseq2n::detail { + +class composite_data_source final : public data_source { +public: + explicit + composite_data_source(std::vector &&pipelines, index_generator_fn &&index_gen_fn, bool stop_at_shortest); + + std::optional + next() override; + + void + reset() override; + + void + record_position(tape &t) const override; + + void + reload_position(tape &t) override; + +private: + std::optional + next_in_pipeline(std::size_t pipeline_idx); + + bool + eod(); + +private: + std::vector pipelines_; + index_generator_fn next_index_gen_; + std::vector> buffer_{}; + std::vector is_epoch_done_; + bool is_eod_ = false; + bool stop_at_shortest_; +}; + +} // namespace fairseq2n::detail diff --git a/fairseq2n/src/fairseq2n/data/data_pipeline.cc b/fairseq2n/src/fairseq2n/data/data_pipeline.cc index 0b6abdcfc..afacf9fae 100644 --- a/fairseq2n/src/fairseq2n/data/data_pipeline.cc +++ b/fairseq2n/src/fairseq2n/data/data_pipeline.cc @@ -203,7 +203,9 @@ data_pipeline::zip( } data_pipeline_builder -data_pipeline::round_robin(std::vector pipelines) +data_pipeline::round_robin( + std::vector pipelines, + bool stop_at_shortest) { bool is_broken = std::any_of( pipelines.begin(), pipelines.end(), [](const data_pipeline &pipeline) @@ -217,9 +219,9 @@ data_pipeline::round_robin(std::vector pipelines) auto tmp = std::make_shared>(std::move(pipelines)); - auto factory = [tmp]() mutable + auto factory = [tmp, stop_at_shortest]() mutable { - return std::make_unique(std::move(*tmp)); + return std::make_unique(std::move(*tmp), stop_at_shortest); }; return data_pipeline_builder{std::move(factory)}; @@ -228,7 +230,8 @@ data_pipeline::round_robin(std::vector pipelines) data_pipeline_builder data_pipeline::sample( std::vector pipelines, - std::optional> weights) + std::optional> weights, + bool stop_at_shortest) { if (pipelines.empty()) throw_( @@ -251,8 +254,8 @@ data_pipeline::sample( auto tmp = std::make_shared>(std::move(pipelines)); - auto factory = [tmp, weights=std::move(weights.value())]() mutable { - return std::make_unique(std::move(*tmp), std::move(weights)); + auto factory = [tmp, weights=std::move(weights.value()), stop_at_shortest]() mutable { + return std::make_unique(std::move(*tmp), std::move(weights), stop_at_shortest); }; return data_pipeline_builder{std::move(factory)}; diff --git a/fairseq2n/src/fairseq2n/data/data_pipeline.h b/fairseq2n/src/fairseq2n/data/data_pipeline.h index 341bc0f8e..28d272488 100644 --- a/fairseq2n/src/fairseq2n/data/data_pipeline.h +++ b/fairseq2n/src/fairseq2n/data/data_pipeline.h @@ -81,12 +81,15 @@ class FAIRSEQ2_API data_pipeline { bool disable_parallelism = false); static data_pipeline_builder - round_robin(std::vector pipelines); + round_robin( + std::vector pipelines, + bool stop_at_shortest = false); static data_pipeline_builder sample( std::vector pipelines, - std::optional> weights = {}); + std::optional> weights = {}, + bool stop_at_shortest = false); static data_pipeline_builder constant(data example, std::optional key = {}); diff --git a/fairseq2n/src/fairseq2n/data/round_robin_data_source.cc b/fairseq2n/src/fairseq2n/data/round_robin_data_source.cc index a81965491..3caa0b045 100644 --- a/fairseq2n/src/fairseq2n/data/round_robin_data_source.cc +++ b/fairseq2n/src/fairseq2n/data/round_robin_data_source.cc @@ -8,110 +8,53 @@ namespace fairseq2n::detail { -std::optional -round_robin_data_source::next() +round_robin_data_source::round_robin_data_source(std::vector &&pipelines, bool stop_at_shortest) { - if (pipelines_.empty() || is_eod_) - return std::nullopt; - - // At the beginning of the next round, check if all data pipelines had at - // least one epoch. If that is the case, we can signal EOD. - if (buffer_idx_ == 0) { - bool all_done = std::all_of( - is_epoch_done_.begin(), is_epoch_done_.end(), [](bool b) - { - return b; - }); + pipelines_count_ = pipelines.size(); + pipeline_idx_ = 0; - if (all_done) { - is_eod_ = true; + auto gen = [this]() + { + pipeline_idx_ %= pipelines_count_; - return std::nullopt; - } - } + return pipeline_idx_++; + }; - // If this is the first call, gather the first round of examples. - if (buffer_.empty()) - for (std::size_t i = 0; i < pipelines_.size(); i++) - buffer_.push_back(next_in_pipeline(i)); - - std::optional output{}; - - // One or more data pipelines might be empty, so we have to check if a - // buffered example has a value before returning it. - for (; !output && buffer_idx_ < buffer_.size(); buffer_idx_++) { - std::optional &maybe_example = buffer_[buffer_idx_]; - if (maybe_example) - // Fill the position with the next round's example. - output = std::exchange(maybe_example, next_in_pipeline(buffer_idx_)); - } + inner_ = std::make_unique(std::move(pipelines), std::move(gen), stop_at_shortest); +} - if (buffer_idx_ == buffer_.size()) - buffer_idx_ = 0; +std::optional +round_robin_data_source::next() +{ + auto output = inner_->next(); + if (!output) + return std::nullopt; - // Might not have a value if all data pipelines were empty. return output; } void round_robin_data_source::reset() { - buffer_.clear(); - - buffer_idx_ = 0; - - is_epoch_done_.assign(pipelines_.size(), false); - - is_eod_ = false; + inner_->reset(); - for (data_pipeline &pipeline : pipelines_) - pipeline.reset(); + pipeline_idx_ = 0; } void round_robin_data_source::record_position(tape &t) const { - t.record(buffer_); + inner_->record_position(t); - t.record(buffer_idx_); - - t.record(is_epoch_done_); - - for (const data_pipeline &pipeline : pipelines_) - pipeline.record_position(t); + t.record(pipeline_idx_); } void round_robin_data_source::reload_position(tape &t) { - buffer_ = t.read>>(); - - buffer_idx_ = t.read(); - - is_epoch_done_ = t.read>(); - - is_eod_ = false; - - for (data_pipeline &pipeline : pipelines_) - pipeline.reload_position(t); -} - -std::optional -round_robin_data_source::next_in_pipeline(std::size_t pipeline_idx) -{ - data_pipeline &pipeline = pipelines_[pipeline_idx]; - - std::optional maybe_example = pipeline.next(); - if (!maybe_example) { - is_epoch_done_[pipeline_idx] = true; - - pipeline.reset(); - - // Circle back to the first example. - maybe_example = pipeline.next(); - } + inner_->reload_position(t); - return maybe_example; + pipeline_idx_ = t.read(); } } // namespace fairseq2n::detail diff --git a/fairseq2n/src/fairseq2n/data/round_robin_data_source.h b/fairseq2n/src/fairseq2n/data/round_robin_data_source.h index b450bb6ef..c43558ad4 100644 --- a/fairseq2n/src/fairseq2n/data/round_robin_data_source.h +++ b/fairseq2n/src/fairseq2n/data/round_robin_data_source.h @@ -12,17 +12,14 @@ #include "fairseq2n/data/data_pipeline.h" #include "fairseq2n/data/data_source.h" +#include "fairseq2n/data/composite_data_source.h" namespace fairseq2n::detail { class round_robin_data_source final : public data_source { public: explicit - round_robin_data_source(std::vector &&pipelines) - : pipelines_(std::move(pipelines)), is_epoch_done_(pipelines_.size()) - { - buffer_.reserve(pipelines_.size()); - } + round_robin_data_source(std::vector &&pipelines, bool stop_at_shortest); std::optional next() override; @@ -41,11 +38,9 @@ class round_robin_data_source final : public data_source { next_in_pipeline(std::size_t pipeline_idx); private: - std::vector pipelines_; - std::vector> buffer_{}; - std::size_t buffer_idx_ = 0; - std::vector is_epoch_done_; - bool is_eod_ = false; + std::unique_ptr inner_; + std::size_t pipeline_idx_; + std::size_t pipelines_count_; }; } // namespace fairseq2n::detail diff --git a/fairseq2n/src/fairseq2n/data/sample_data_source.cc b/fairseq2n/src/fairseq2n/data/sample_data_source.cc index 17c1f9511..173ea3612 100644 --- a/fairseq2n/src/fairseq2n/data/sample_data_source.cc +++ b/fairseq2n/src/fairseq2n/data/sample_data_source.cc @@ -10,28 +10,32 @@ #include #include -#include - #include "fairseq2n/utils/tensor.h" namespace fairseq2n::detail { -sample_data_source::sample_data_source(std::vector &&pipelines, std::vector &&weights) - : pipelines_(std::move(pipelines)) +sample_data_source::sample_data_source(std::vector &&pipelines, std::vector &&weights, bool stop_at_shortest) { - weights_ = make_tensor_from_vector(weights, { static_cast(pipelines_.size()) }); + weights_ = make_tensor_from_vector(weights, { static_cast(pipelines.size()) }); generator_ = at::globalContext().defaultGenerator(c10::DeviceType::CPU); + + auto gen = [this]() + { + auto result = at::multinomial(weights_, 1, false, generator_) + .item(); + + return static_cast(result); + }; + + inner_ = std::make_unique(std::move(pipelines), std::move(gen), stop_at_shortest); } std::optional sample_data_source::next() { - if (eod_) - return std::nullopt; - - std::optional output = pipelines_[next_index()].next(); + auto output = inner_->next(); if (!output) - eod_ = true; + return std::nullopt; return output; } @@ -39,34 +43,19 @@ sample_data_source::next() void sample_data_source::reset() { - eod_ = false; - for (data_pipeline &p : pipelines_) - p.reset(); + inner_->reset(); } void sample_data_source::record_position(tape &t) const { - t.record(eod_); - for (const data_pipeline &p : pipelines_) - p.record_position(t); + inner_->record_position(t); } void sample_data_source::reload_position(tape &t) { - eod_ = t.read(); - for (data_pipeline &p : pipelines_) - p.reload_position(t); -} - -std::size_t -sample_data_source::next_index() -{ - auto result = at::multinomial(weights_, 1, false, generator_) - .item(); - - return static_cast(result); + inner_->reload_position(t); } } // namespace fairseq2::detail diff --git a/fairseq2n/src/fairseq2n/data/sample_data_source.h b/fairseq2n/src/fairseq2n/data/sample_data_source.h index a31bb34b1..e26882bfd 100644 --- a/fairseq2n/src/fairseq2n/data/sample_data_source.h +++ b/fairseq2n/src/fairseq2n/data/sample_data_source.h @@ -14,6 +14,7 @@ #include "fairseq2n/data/data_pipeline.h" #include "fairseq2n/data/data_source.h" +#include "fairseq2n/data/composite_data_source.h" namespace fairseq2n::detail { @@ -21,7 +22,7 @@ namespace fairseq2n::detail { class sample_data_source final : public data_source { public: explicit - sample_data_source(std::vector &&pipelines, std::vector &&weights); + sample_data_source(std::vector &&pipelines, std::vector &&weights, bool stop_at_shortest); std::optional next() override; @@ -40,8 +41,7 @@ class sample_data_source final : public data_source { next_index(); private: - std::vector pipelines_; - bool eod_ = false; + std::unique_ptr inner_; at::Generator generator_; at::Tensor weights_; diff --git a/src/fairseq2/data/data_pipeline.py b/src/fairseq2/data/data_pipeline.py index 42edf963f..063e3b676 100644 --- a/src/fairseq2/data/data_pipeline.py +++ b/src/fairseq2/data/data_pipeline.py @@ -94,17 +94,25 @@ def zip( """ @staticmethod - def round_robin(pipelines: Sequence["DataPipeline"]) -> "DataPipelineBuilder": + def round_robin( + pipelines: Sequence["DataPipeline"], + stop_at_shortest: bool = False, + ) -> "DataPipelineBuilder": """Extract examples from ``pipelines`` in round robin. :param pipelines: The data pipelines to round robin. + :param stop_at_shortest: + If ``True``, stop round_robin when first pipeline reaches its end. + If ``False``, circle around finished pipelines until all pipelines + reach their end. """ @staticmethod def sample( pipelines: Sequence["DataPipeline"], weights: Optional[Sequence[float]] = None, + stop_at_shortest: bool = False, ) -> "DataPipelineBuilder": """Extract examples from ``pipelines`` by sampling based on ``weights``. @@ -112,6 +120,10 @@ def sample( The data pipelines to sample from. :param weights: Desired distribution of pipelines. If None, use uniform distribution. + :param stop_at_shortest: + If ``True``, stop sampling when first pipeline reaches its end. + If ``False``, circle around finished pipelines until all pipelines + reach their end. """ @staticmethod diff --git a/tests/unit/data/data_pipeline/test_round_robin.py b/tests/unit/data/data_pipeline/test_round_robin.py index 406de6b46..e8ba0e984 100644 --- a/tests/unit/data/data_pipeline/test_round_robin.py +++ b/tests/unit/data/data_pipeline/test_round_robin.py @@ -8,6 +8,7 @@ from fairseq2.data import DataPipeline, DataPipelineError, read_sequence from fairseq2.data.text import read_text +from tests.common import python_devel_only class TestRoundRobinOp: @@ -76,6 +77,28 @@ def test_op_works_when_pipelines_have_different_lengths(self) -> None: pipeline.reset() + @pytest.mark.skipif( + python_devel_only(), + reason="New fairseq2n API in Python-only installation. Skipping till v0.2.", + ) + def test_op_works_when_pipelines_have_different_lengths_stop_at_shortest( + self, + ) -> None: + pipeline1 = read_sequence([1, 2, 3, 4]).and_return() + pipeline2 = read_sequence([5, 6]).and_return() + pipeline3 = read_sequence([7, 8, 9, 0, 1, 2]).and_return() + + pipeline = DataPipeline.round_robin( + [pipeline1, pipeline2, pipeline3], stop_at_shortest=True + ).and_return() + + seq = [1, 5, 7, 2, 6, 8, 3] + + for _ in range(2): + assert list(pipeline) == seq + + pipeline.reset() + def test_op_raises_error_when_one_of_the_pipelines_is_broken(self) -> None: # Force a non-recoverable error. pipeline1 = read_text(pathname=" &^#").and_return() diff --git a/tests/unit/data/data_pipeline/test_sample.py b/tests/unit/data/data_pipeline/test_sample.py index fcb71254c..e73a50c05 100644 --- a/tests/unit/data/data_pipeline/test_sample.py +++ b/tests/unit/data/data_pipeline/test_sample.py @@ -27,8 +27,9 @@ class TestSampleOp: def test_op_works(self) -> None: dp1 = read_sequence([1, 2, 3]).and_return() dp2 = read_sequence([11, 12, 13]).and_return() - - rdp = DataPipeline.sample([dp1, dp2], [0.5, 0.5]).and_return() + rdp = DataPipeline.sample( + [dp1, dp2], [0.5, 0.5], stop_at_shortest=True + ).and_return() for _ in range(5): with tmp_rng_seed(cpu_device, seed=1234): @@ -40,7 +41,9 @@ def test_op_works_when_pipelines_have_different_lengths(self) -> None: dp1 = read_sequence([1, 2, 3]).and_return() dp2 = read_sequence([11, 12]).and_return() - rdp = DataPipeline.sample([dp1, dp2], [7, 4]).and_return() + rdp = DataPipeline.sample( + [dp1, dp2], [7, 4], stop_at_shortest=True + ).and_return() for _ in range(2): with tmp_rng_seed(cpu_device, seed=1234): @@ -52,8 +55,7 @@ def test_op_works_when_no_weights_is_specified(self) -> None: dp1 = read_sequence([1, 2, 3, 4, 5]).and_return() dp2 = read_sequence([11, 12]).and_return() dp3 = read_sequence([101, 102, 103]).and_return() - - rdp = DataPipeline.sample([dp1, dp2, dp3]).and_return() + rdp = DataPipeline.sample([dp1, dp2, dp3], stop_at_shortest=True).and_return() for _ in range(2): with tmp_rng_seed(cpu_device, seed=1234): @@ -64,8 +66,9 @@ def test_op_works_when_no_weights_is_specified(self) -> None: def test_op_works_when_weight_is_low(self) -> None: dp1 = read_sequence([1, 2, 3, 4, 5]).and_return() dp2 = read_sequence([11, 12]).and_return() - - rdp = DataPipeline.sample([dp1, dp2], [0.9, 0.1]).and_return() + rdp = DataPipeline.sample( + [dp1, dp2], [0.9, 0.1], stop_at_shortest=True + ).and_return() for _ in range(2): with tmp_rng_seed(cpu_device, seed=1234): @@ -75,8 +78,7 @@ def test_op_works_when_weight_is_low(self) -> None: def test_op_works_when_a_single_pipeline_is_specified(self) -> None: dp = read_sequence([1, 2, 3, 4, 5]).and_return() - - rdp = DataPipeline.sample([dp]).and_return() + rdp = DataPipeline.sample([dp], stop_at_shortest=True).and_return() for _ in range(2): with tmp_rng_seed(cpu_device, seed=1234): @@ -101,7 +103,9 @@ def test_op_works_when_seed_is_set_manually(self) -> None: dp1 = read_sequence([1, 2, 3]).and_return() dp2 = read_sequence([11, 12]).and_return() - rdp = DataPipeline.sample([dp1, dp2], [0.4, 0.6]).and_return() + rdp = DataPipeline.sample( + [dp1, dp2], [0.4, 0.6], stop_at_shortest=True + ).and_return() for _ in range(2): with tmp_rng_seed(cpu_device, seed=1234): @@ -115,6 +119,19 @@ def test_op_works_when_seed_is_set_manually(self) -> None: rdp.reset() + def test_op_works_when_up_sampling(self) -> None: + dp1 = read_sequence([1, 2, 3, 4, 5]).and_return() + dp2 = read_sequence([11, 12]).and_return() + + rdp = DataPipeline.sample( + [dp1, dp2], [0.5, 0.5], stop_at_shortest=False + ).and_return() + for _ in range(2): + with tmp_rng_seed(cpu_device, seed=1234): + assert list(rdp) == [11, 1, 12, 2, 3, 11, 4, 12, 11, 12, 5] + + rdp.reset() + def test_op_raises_error_when_weight_is_negative(self) -> None: dl1 = read_sequence([1, 2, 3, 4, 5]).and_return() dl2 = read_sequence([11, 12]).and_return() @@ -180,8 +197,8 @@ def test_op_saves_and_restores_its_state(self) -> None: dp2 = read_sequence(list(range(10, 18))).and_return() dp3 = read_sequence(list(range(20, 26))).and_return() + rdp = DataPipeline.sample([dp1, dp2, dp3], stop_at_shortest=True).and_return() # [10, 0, 11, 20, 1, 21, 22, 23, 24, 12, 2, 25, 13, 3] - rdp = DataPipeline.sample([dp1, dp2, dp3]).and_return() d = None