diff --git a/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc b/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc
index 07307e5eb..eb82ae0c3 100644
--- a/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc
+++ b/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc
@@ -358,8 +358,24 @@ def_data_pipeline(py::module_ &data_module)
                 return data_pipeline::count(start, std::move(key));
             },
             py::arg("start") = 0,
-            py::arg("key") = std::nullopt);
+            py::arg("key") = std::nullopt)
+        .def_static(
+            "concat",
+            [](std::vector<std::reference_wrapper<data_pipeline>> &refs)
+            {
+                std::vector<data_pipeline> pipelines{};
+
+                pipelines.reserve(refs.size());
+
+                std::transform(
+                    refs.begin(), refs.end(), std::back_inserter(pipelines), [](auto &r) {
+                        return std::move(r.get());
+                    });
+
+                return data_pipeline::concat(std::move(pipelines));
+            },
+            py::arg("pipelines"));
 
     // DataPipelineIterator
     py::class_<data_pipeline_iterator>(m, "_DataPipelineIterator")
         .def(
diff --git a/fairseq2n/src/fairseq2n/CMakeLists.txt b/fairseq2n/src/fairseq2n/CMakeLists.txt
index c3e68d1bb..19dd4a873 100644
--- a/fairseq2n/src/fairseq2n/CMakeLists.txt
+++ b/fairseq2n/src/fairseq2n/CMakeLists.txt
@@ -18,6 +18,7 @@ target_sources(fairseq2n
         data/bucket_data_source.cc
         data/byte_stream.cc
        data/collater.cc
+        data/concat_data_source.cc
         data/constant_data_source.cc
         data/count_data_source.cc
         data/data.cc
diff --git a/fairseq2n/src/fairseq2n/data/concat_data_source.cc b/fairseq2n/src/fairseq2n/data/concat_data_source.cc
new file mode 100644
index 000000000..755c39f38
--- /dev/null
+++ b/fairseq2n/src/fairseq2n/data/concat_data_source.cc
@@ -0,0 +1,47 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "fairseq2n/data/concat_data_source.h"
+
+#include <utility>
+
+namespace fairseq2n::detail {
+
+concat_data_source::concat_data_source(std::vector<data_pipeline> &&pipelines)
+    : pipelines_(std::move(pipelines))
+{}
+
+std::optional<data>
+concat_data_source::next()
+{
+    std::optional<data> d;
+    for (auto &p : pipelines_) {
+        d = p.next();
+        if (d)
+            return d;
+    }
+    return {};
+}
+
+void concat_data_source::reset()
+{
+    for (auto &pipeline : pipelines_)
+        pipeline.reset();
+}
+
+void concat_data_source::record_position(tape &t) const
+{
+    for (auto &pipeline : pipelines_)
+        pipeline.record_position(t);
+}
+
+void concat_data_source::reload_position(tape &t)
+{
+    for (auto &pipeline : pipelines_)
+        pipeline.reload_position(t);
+}
+
+} // namespace fairseq2n::detail
diff --git a/fairseq2n/src/fairseq2n/data/concat_data_source.h b/fairseq2n/src/fairseq2n/data/concat_data_source.h
new file mode 100644
index 000000000..236538029
--- /dev/null
+++ b/fairseq2n/src/fairseq2n/data/concat_data_source.h
@@ -0,0 +1,40 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <cstddef>
+#include <optional>
+#include <vector>
+
+#include "fairseq2n/data/data_pipeline.h"
+#include "fairseq2n/data/data_source.h"
+
+namespace fairseq2n::detail {
+
+class concat_data_source final : public data_source {
+public:
+    explicit
+    concat_data_source(
+        std::vector<data_pipeline> &&pipelines);
+
+    std::optional<data>
+    next() override;
+
+    void
+    reset() override;
+
+    void
+    record_position(tape &t) const override;
+
+    void
+    reload_position(tape &t) override;
+
+private:
+    std::vector<data_pipeline> pipelines_;
+};
+
+} // namespace fairseq2n::detail
diff --git a/fairseq2n/src/fairseq2n/data/data_pipeline.cc b/fairseq2n/src/fairseq2n/data/data_pipeline.cc
index afacf9fae..7b651450f 100644
--- a/fairseq2n/src/fairseq2n/data/data_pipeline.cc
+++ b/fairseq2n/src/fairseq2n/data/data_pipeline.cc
@@ -11,25 +11,27 @@
 #include <system_error>
 #include <utility>
 
+#include "data_pipeline.h"
 #include "fairseq2n/data/bucket_by_length_data_source.h"
 #include "fairseq2n/data/bucket_data_source.h"
+#include "fairseq2n/data/concat_data_source.h"
 #include "fairseq2n/data/constant_data_source.h"
 #include "fairseq2n/data/count_data_source.h"
+#include "fairseq2n/data/detail/file_system.h"
 #include "fairseq2n/data/filter_data_source.h"
 #include "fairseq2n/data/list_data_source.h"
 #include "fairseq2n/data/map_data_source.h"
 #include "fairseq2n/data/prefetch_data_source.h"
-#include "fairseq2n/data/take_data_source.h"
 #include "fairseq2n/data/round_robin_data_source.h"
 #include "fairseq2n/data/sample_data_source.h"
 #include "fairseq2n/data/shard_data_source.h"
 #include "fairseq2n/data/shuffle_data_source.h"
 #include "fairseq2n/data/skip_data_source.h"
+#include "fairseq2n/data/take_data_source.h"
 #include "fairseq2n/data/tape.h"
 #include "fairseq2n/data/yield_from_data_source.h"
 #include "fairseq2n/data/zip_data_source.h"
 #include "fairseq2n/data/zip_file_data_source.h"
-#include "fairseq2n/data/detail/file_system.h"
 #include "fairseq2n/detail/exception.h"
 
 using namespace fairseq2n::detail;
@@ -283,6 +285,34 @@ data_pipeline::count(std::int64_t start, std::optional<std::string> key)
     return data_pipeline_builder{std::move(factory)};
 }
 
+data_pipeline_builder
+data_pipeline::concat(
+    std::vector<data_pipeline> pipelines)
+{
+    if (pipelines.empty())
+        throw_<std::invalid_argument>(
+            "`pipelines` does not contain any elements. Can not concatenate.");
+
+    bool is_broken = std::any_of(
+        pipelines.begin(), pipelines.end(), [](const data_pipeline &pipeline)
+        {
+            return pipeline.is_broken();
+        });
+
+    if (is_broken)
+        throw_<std::invalid_argument>(
+            "At least one of the specified data pipelines is broken and cannot be concatenated.");
+
+    auto tmp = std::make_shared<std::vector<data_pipeline>>(std::move(pipelines));
+
+    auto factory = [tmp]() mutable
+    {
+        return std::make_unique<concat_data_source>(std::move(*tmp));
+    };
+
+    return data_pipeline_builder{std::move(factory)};
+}
+
 data_pipeline_builder
 data_pipeline_builder::bucket(std::size_t bucket_size, bool drop_remainder) &&
 {
diff --git a/fairseq2n/src/fairseq2n/data/data_pipeline.h b/fairseq2n/src/fairseq2n/data/data_pipeline.h
index 28d272488..350541992 100644
--- a/fairseq2n/src/fairseq2n/data/data_pipeline.h
+++ b/fairseq2n/src/fairseq2n/data/data_pipeline.h
@@ -97,6 +97,9 @@ class FAIRSEQ2_API data_pipeline {
     static data_pipeline_builder
     count(std::int64_t start = 0, std::optional<std::string> key = {});
 
+    static data_pipeline_builder
+    concat(std::vector<data_pipeline> pipelines);
+
 private:
     data_source_factory factory_{};
     std::unique_ptr<data_source> source_{};
diff --git a/src/fairseq2/data/data_pipeline.py b/src/fairseq2/data/data_pipeline.py
index edd8434ec..eb88f6b06 100644
--- a/src/fairseq2/data/data_pipeline.py
+++ b/src/fairseq2/data/data_pipeline.py
@@ -135,6 +135,15 @@ def constant(example: Any, key: Optional[str] = None) -> "DataPipelineBuilder":
     def count(start: int = 0, key: Optional[str] = None) -> "DataPipelineBuilder":
         ...
 
+    @staticmethod
+    def concat(pipelines: Sequence["DataPipeline"]) -> "DataPipelineBuilder":
+        """Concatenate examples from ``pipelines``.
+
+        :param pipelines:
+            The data pipelines to concatenate.
+        """
+        ...
+
 
 class DataPipelineBuilder:
     """API to create DataPipeline"""
diff --git a/tests/unit/data/data_pipeline/test_concat.py b/tests/unit/data/data_pipeline/test_concat.py
new file mode 100644
index 000000000..659bf3f8b
--- /dev/null
+++ b/tests/unit/data/data_pipeline/test_concat.py
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+
+from fairseq2.data import DataPipeline, DataPipelineError, read_sequence
+from fairseq2.data.text import read_text
+from tests.common import python_devel_only
+
+
+@pytest.mark.skipif(
+    python_devel_only(),
+    reason="New fairseq2n API in Python-only installation. Skipping till v0.2.",
+)
+class TestConcatOp:
+    def test_op_works(self) -> None:
+        pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
+        pipeline2 = read_sequence([5, 6, 7, 8]).and_return()
+
+        pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()
+
+        for _ in range(2):
+            assert list(pipeline) == [1, 2, 3, 4, 5, 6, 7, 8]
+            pipeline.reset()
+
+    def test_op_works_when_pipelines_are_empty(self) -> None:
+        pipeline1 = read_sequence([]).and_return()
+        pipeline2 = read_sequence([]).and_return()
+
+        pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()
+
+        for _ in range(2):
+            assert list(pipeline) == []
+            pipeline.reset()
+
+    def test_op_works_when_pipelines_have_different_lengths(self) -> None:
+        pipeline1 = read_sequence([1, 2, 3]).and_return()
+        pipeline2 = read_sequence([4, 5]).and_return()
+
+        pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()
+
+        for _ in range(2):
+            assert list(pipeline) == [1, 2, 3, 4, 5]
+            pipeline.reset()
+
+    def test_op_raises_error_when_one_of_the_pipelines_is_broken(self) -> None:
+        # Force a non-recoverable error.
+        pipeline1 = read_text(pathname=" &^#").and_return()
+        pipeline2 = read_text(pathname=" &^#").and_return()
+
+        # Break the first pipeline.
+        try:
+            next(iter(pipeline1))
+        except DataPipelineError:
+            assert pipeline1.is_broken
+
+        with pytest.raises(
+            ValueError,
+            match=r"^At least one of the specified data pipelines is broken and cannot be concatenated\.$",
+        ):
+            DataPipeline.concat([pipeline1, pipeline2]).and_return()
+
+    def test_op_saves_and_restores_its_state(self) -> None:
+        pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
+        pipeline2 = read_sequence([5, 6, 7, 8]).and_return()
+
+        pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()
+
+        d = None
+
+        it = iter(pipeline)
+
+        # Move to the second example.
+        for _ in range(2):
+            d = next(it)
+
+        assert d == 2
+
+        state_dict = pipeline.state_dict()
+
+        # Read one more example before we roll back.
+        d = next(it)
+
+        assert d == 3
+
+        # Expected to roll back to the second example.
+        pipeline.load_state_dict(state_dict)
+
+        # Move to EOD.
+        for _ in range(6):
+            d = next(it)
+
+        assert d == 8
+
+        state_dict = pipeline.state_dict()
+
+        pipeline.reset()
+
+        # Expected to be EOD.
+        pipeline.load_state_dict(state_dict)
+
+        with pytest.raises(StopIteration):
+            next(iter(pipeline))
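
Usage note (not part of the patch): a minimal sketch of the new operator from the Python side, mirroring the unit tests above. It only uses names that appear in this diff (DataPipeline.concat, read_sequence, and_return).

    from fairseq2.data import DataPipeline, read_sequence

    # Build two finite pipelines and chain them; examples are drawn from
    # pipeline1 until it is exhausted, then from pipeline2.
    pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
    pipeline2 = read_sequence([5, 6, 7, 8]).and_return()

    pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()

    assert list(pipeline) == [1, 2, 3, 4, 5, 6, 7, 8]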