diff --git a/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc b/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc index 176d03929..07307e5eb 100644 --- a/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc +++ b/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc @@ -403,6 +403,32 @@ def_data_pipeline(py::module_ &data_module) py::arg("bucket_sizes"), py::arg("selector") = std::nullopt, py::arg("drop_remainder") = false) + .def( + "collate", + []( + data_pipeline_builder &self, + std::optional<std::int64_t> maybe_pad_idx, + std::int64_t pad_to_multiple, + std::optional<std::vector<collate_options_override>> maybe_opt_overrides, + std::size_t num_parallel_calls) -> data_pipeline_builder & + { + auto opts = collate_options() + .maybe_pad_idx(maybe_pad_idx).pad_to_multiple(pad_to_multiple); + + std::vector<collate_options_override> opt_overrides{}; + if (maybe_opt_overrides) + opt_overrides = *std::move(maybe_opt_overrides); + + map_fn f = collater(opts, std::move(opt_overrides)); + + self = std::move(self).map(std::move(f), num_parallel_calls); + + return self; + }, + py::arg("pad_idx") = std::nullopt, + py::arg("pad_to_multiple") = 1, + py::arg("overrides") = std::nullopt, + py::arg("num_parallel_calls") = 1) .def( + "filter", [](data_pipeline_builder &self, predicate_fn fn) -> data_pipeline_builder & diff --git a/src/fairseq2/data/data_pipeline.py b/src/fairseq2/data/data_pipeline.py index 063e3b676..edd8434ec 100644 --- a/src/fairseq2/data/data_pipeline.py +++ b/src/fairseq2/data/data_pipeline.py @@ -33,8 +33,9 @@ class DataPipeline(Iterable[Any]): The pipeline state can be persisted to the disk, allowing it to be resumed later. It is a Python Iterable, but it also contains the iterator states. - Calling `iter` a second time while the first iterator is still being used - will segfault or worse. + + Calling `iter` twice will create two iterators reading from the same dataloader, + and sharing the same state, so it will behave inconsistently. 
""" def __iter__(self) -> Iterator[Any]: @@ -155,6 +156,18 @@ def bucket_by_length( ) -> Self: """Combine examples of similar shape into batches.""" + def collate( + self, + pad_idx: Optional[int] = None, + pad_to_multiple: int = 1, + overrides: Optional[Sequence["CollateOptionsOverride"]] = None, + ) -> Self: + """Concatenate a list of inputs into a single inputs. + + This is equivalent to calling `.map(Collater())`. + See :py:class:`fairseq2.data.Collater` for details. + """ + def filter(self, predicate: Callable[[Any], Any]) -> Self: """Filter examples from data pipeline and keep only those who match ``predicate``. diff --git a/tests/unit/data/test_collater.py b/tests/unit/data/test_collater.py index 8bb08bd73..c2cc13c08 100644 --- a/tests/unit/data/test_collater.py +++ b/tests/unit/data/test_collater.py @@ -8,8 +8,8 @@ import torch from torch.nn.functional import pad -from fairseq2.data import CollateOptionsOverride, Collater -from tests.common import assert_close, assert_equal, device +from fairseq2.data import CollateOptionsOverride, Collater, read_sequence +from tests.common import assert_close, assert_equal, device, python_devel_only class TestCollater: @@ -378,3 +378,51 @@ def test_init_raises_error_when_pad_idx_is_none_and_pad_to_multiple_is_greater_t match=r"^`pad_idx` of the selector 'foo' must be set when `pad_to_multiple` is greater than 1\.$", ): Collater(overrides=[CollateOptionsOverride("foo", pad_to_multiple=2)]) + + +@pytest.mark.skipif(python_devel_only(), reason="fairseq2n 0.2.0") +@pytest.mark.parametrize("pad_to_multiple,pad_size", [(1, 0), (2, 0), (3, 2), (8, 4)]) +def test_collate_works_when_input_has_sequence_tensors( + pad_to_multiple: int, pad_size: int +) -> None: + bucket1 = [ + torch.full((4, 2), 0, device=device, dtype=torch.int64), + torch.full((4, 2), 1, device=device, dtype=torch.int64), + torch.full((4, 2), 2, device=device, dtype=torch.int64), + ] + + bucket2 = [ + [{"foo1": 0, "foo2": 1}, {"foo3": 2, "foo4": 3}], + [{"foo1": 4, 
"foo2": 5}, {"foo3": 6, "foo4": 7}], + [{"foo1": 8, "foo2": 9}, {"foo3": 0, "foo4": 1}], + ] + + expected1_seqs = torch.tensor( + [ + [[0, 0], [0, 0], [0, 0], [0, 0]], + [[1, 1], [1, 1], [1, 1], [1, 1]], + [[2, 2], [2, 2], [2, 2], [2, 2]], + ], + device=device, + dtype=torch.int64, + ) + expected1_seqs = pad(expected1_seqs, (0, 0, 0, pad_size), value=3) + expected1_seq_lens = torch.tensor([4, 4, 4], device=device, dtype=torch.int64) + + expected2 = [ + {"foo1": [0, 4, 8], "foo2": [1, 5, 9]}, + {"foo3": [2, 6, 0], "foo4": [3, 7, 1]}, + ] + + data = ( + read_sequence([bucket1, bucket2]) + .collate(pad_idx=3, pad_to_multiple=pad_to_multiple) + .and_return() + ) + output1, output2 = list(data) + + assert_close(output1["seqs"], expected1_seqs) + assert_equal(output1["seq_lens"], expected1_seq_lens) + assert output1["is_ragged"] == False + + assert output2 == expected2