Skip to content

Commit

Permalink
add .collate for .map(Collater) (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
gwenzek authored Oct 10, 2023
1 parent 199cf93 commit 8deaba4
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 4 deletions.
26 changes: 26 additions & 0 deletions fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,32 @@ def_data_pipeline(py::module_ &data_module)
py::arg("bucket_sizes"),
py::arg("selector") = std::nullopt,
py::arg("drop_remainder") = false)
// Bind `DataPipelineBuilder.collate`: a convenience shortcut equivalent to
// `.map(Collater(opts), num_parallel_calls)` — builds a collater from the
// given padding options and applies it as a map function on the pipeline.
.def(
"collate",
[](
data_pipeline_builder &self,
std::optional<std::int64_t> maybe_pad_idx,
std::int64_t pad_to_multiple,
std::optional<std::vector<collate_options_override>> maybe_opt_overrides,
std::size_t num_parallel_calls) -> data_pipeline_builder &
{
// Base collate options; per-selector overrides (below) can refine them.
auto opts = collate_options()
.maybe_pad_idx(maybe_pad_idx).pad_to_multiple(pad_to_multiple);

// Unwrap the optional override list into a concrete (possibly empty) vector.
std::vector<collate_options_override> opt_overrides{};
if (maybe_opt_overrides)
opt_overrides = *std::move(maybe_opt_overrides);

// A collater is just a map function over pipeline examples.
map_fn f = collater(opts, std::move(opt_overrides));

// The builder is consumed by `map`; reassign so callers can keep chaining.
self = std::move(self).map(std::move(f), num_parallel_calls);

return self;
},
py::arg("pad_idx") = std::nullopt,
py::arg("pad_to_multiple") = 1,
py::arg("overrides") = std::nullopt,
py::arg("num_parallel_calls") = 1)
.def(
"filter",
[](data_pipeline_builder &self, predicate_fn fn) -> data_pipeline_builder &
Expand Down
17 changes: 15 additions & 2 deletions src/fairseq2/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ class DataPipeline(Iterable[Any]):
The pipeline state can be persisted to the disk, allowing it to be resumed later.
It is a Python Iterable, but it also contains the iterator states.
Calling `iter` a second time while the first iterator is still being used
will segfault or worse.
Calling `iter` twice will create two iterators reading from the same dataloader
and sharing the same state, so iteration will behave inconsistently.
"""

def __iter__(self) -> Iterator[Any]:
Expand Down Expand Up @@ -155,6 +156,18 @@ def bucket_by_length(
) -> Self:
"""Combine examples of similar shape into batches."""

def collate(
    self,
    pad_idx: Optional[int] = None,
    pad_to_multiple: int = 1,
    overrides: Optional[Sequence["CollateOptionsOverride"]] = None,
) -> Self:
    """Concatenate a list of inputs into a single input.

    This is equivalent to calling ``.map(Collater())``.
    See :py:class:`fairseq2.data.Collater` for details.

    :param pad_idx:
        The index used to pad variable-length sequences.
    :param pad_to_multiple:
        Pad sequences so their length is a multiple of this value.
    :param overrides:
        Collate options applied to specific selectors, overriding the
        defaults above.
    """

def filter(self, predicate: Callable[[Any], Any]) -> Self:
"""Filter examples from data pipeline and keep only those who match
``predicate``.
Expand Down
52 changes: 50 additions & 2 deletions tests/unit/data/test_collater.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import torch
from torch.nn.functional import pad

from fairseq2.data import CollateOptionsOverride, Collater
from tests.common import assert_close, assert_equal, device
from fairseq2.data import CollateOptionsOverride, Collater, read_sequence
from tests.common import assert_close, assert_equal, device, python_devel_only


class TestCollater:
Expand Down Expand Up @@ -378,3 +378,51 @@ def test_init_raises_error_when_pad_idx_is_none_and_pad_to_multiple_is_greater_t
match=r"^`pad_idx` of the selector 'foo' must be set when `pad_to_multiple` is greater than 1\.$",
):
Collater(overrides=[CollateOptionsOverride("foo", pad_to_multiple=2)])


@pytest.mark.skipif(python_devel_only(), reason="fairseq2n 0.2.0")
@pytest.mark.parametrize("pad_to_multiple,pad_size", [(1, 0), (2, 0), (3, 2), (8, 4)])
def test_collate_works_when_input_has_sequence_tensors(
    pad_to_multiple: int, pad_size: int
) -> None:
    """``.collate()`` must behave like ``.map(Collater(...))``.

    A bucket of equal-shape tensors is stacked into one batch (padded with
    ``pad_idx`` up to a multiple of ``pad_to_multiple``), while a bucket of
    nested dicts is collated element-wise into per-key lists.
    """
    # Bucket of three identically shaped (4, 2) tensors -> stacked to (3, 4, 2).
    bucket1 = [
        torch.full((4, 2), 0, device=device, dtype=torch.int64),
        torch.full((4, 2), 1, device=device, dtype=torch.int64),
        torch.full((4, 2), 2, device=device, dtype=torch.int64),
    ]

    # Bucket of nested dicts -> values gathered into lists keyed by dict key.
    bucket2 = [
        [{"foo1": 0, "foo2": 1}, {"foo3": 2, "foo4": 3}],
        [{"foo1": 4, "foo2": 5}, {"foo3": 6, "foo4": 7}],
        [{"foo1": 8, "foo2": 9}, {"foo3": 0, "foo4": 1}],
    ]

    expected1_seqs = torch.tensor(
        [
            [[0, 0], [0, 0], [0, 0], [0, 0]],
            [[1, 1], [1, 1], [1, 1], [1, 1]],
            [[2, 2], [2, 2], [2, 2], [2, 2]],
        ],
        device=device,
        dtype=torch.int64,
    )
    # Pad the sequence dimension with the pad index (3) to reach a multiple
    # of `pad_to_multiple`; `pad_size` is the expected amount of padding.
    expected1_seqs = pad(expected1_seqs, (0, 0, 0, pad_size), value=3)
    # Sequence lengths reflect the unpadded lengths.
    expected1_seq_lens = torch.tensor([4, 4, 4], device=device, dtype=torch.int64)

    expected2 = [
        {"foo1": [0, 4, 8], "foo2": [1, 5, 9]},
        {"foo3": [2, 6, 0], "foo4": [3, 7, 1]},
    ]

    data = (
        read_sequence([bucket1, bucket2])
        .collate(pad_idx=3, pad_to_multiple=pad_to_multiple)
        .and_return()
    )
    output1, output2 = list(data)

    assert_close(output1["seqs"], expected1_seqs)
    assert_equal(output1["seq_lens"], expected1_seq_lens)
    # All sequences have equal length, so the batch must not be ragged.
    # Use `is False` (identity) rather than `== False` (flake8 E712).
    assert output1["is_ragged"] is False

    assert output2 == expected2

0 comments on commit 8deaba4

Please sign in to comment.