Skip to content

Commit

Permalink
add .collate for .map(Collater)
Browse files Browse the repository at this point in the history
  • Loading branch information
gwenzek committed Oct 10, 2023
1 parent 199cf93 commit 4b71687
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 1 deletion.
27 changes: 27 additions & 0 deletions fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,33 @@ def_data_pipeline(py::module_ &data_module)
py::arg("bucket_sizes"),
py::arg("selector") = std::nullopt,
py::arg("drop_remainder") = false)
// Exposes `collate` on the pipeline builder: builds a collater from the given
// options and appends it to the pipeline as a map step.
.def(
"collate",
[](
data_pipeline_builder &self,
std::optional<std::int64_t> maybe_pad_idx,
std::int64_t pad_to_multiple,
std::optional<std::vector<collate_options_override>> maybe_opt_overrides,
std::size_t num_parallel_calls) -> data_pipeline_builder &
{
// Base collate options assembled from the two scalar arguments.
auto opts = collate_options()
.maybe_pad_idx(maybe_pad_idx).pad_to_multiple(pad_to_multiple);

// When the caller passes no overrides, the collater receives an empty list.
std::vector<collate_options_override> opt_overrides{};
if (maybe_opt_overrides)
opt_overrides = *std::move(maybe_opt_overrides);

// Wrap the collater as a generic map function.
// NOTE(review): the `std::nullopt` second argument is presumably "no
// selector" (apply to the whole element) - confirm against element_mapper.
map_fn f = collater(opts, std::move(opt_overrides));
element_mapper mapper{f, std::nullopt};

// `map` is invoked on the moved-from builder and its result is assigned
// back to `self`, so the same reference can be returned for chaining.
self = std::move(self).map(std::move(mapper), num_parallel_calls);

return self;
},
// Python-side keyword names and defaults; `maybe_pad_idx` and
// `maybe_opt_overrides` are exposed as `pad_idx` and `overrides`.
py::arg("pad_idx") = std::nullopt,
py::arg("pad_to_multiple") = 1,
py::arg("overrides") = std::nullopt,
py::arg("num_parallel_calls") = 1)
.def(
"filter",
[](data_pipeline_builder &self, predicate_fn fn) -> data_pipeline_builder &
Expand Down
11 changes: 11 additions & 0 deletions src/fairseq2/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,17 @@ def bucket_by_length(
) -> Self:
"""Combine examples of similar shape into batches."""

def collate(
    self,
    pad_idx: Optional[int] = None,
    pad_to_multiple: int = 1,
    overrides: Optional[Sequence["CollateOptionsOverride"]] = None,
    num_parallel_calls: int = 1,
) -> Self:
    """Collate each example of the pipeline into a single batch.

    The native implementation builds a collater from the given options and
    applies it to every example through a map step.
    See :py:class:`fairseq2.data.Collater` for details.

    :param pad_idx:
        The value used to pad sequences. Must be set when
        ``pad_to_multiple`` is greater than 1.
    :param pad_to_multiple:
        Sequence lengths are padded up to a multiple of this value.
    :param overrides:
        Per-selector options that take precedence over ``pad_idx`` and
        ``pad_to_multiple``.
    :param num_parallel_calls:
        Forwarded to the underlying map step; the native binding defaults
        it to 1.
    """

def filter(self, predicate: Callable[[Any], Any]) -> Self:
"""Filter examples from data pipeline and keep only those who match
``predicate``.
Expand Down
49 changes: 48 additions & 1 deletion tests/unit/data/test_collater.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
from torch.nn.functional import pad

from fairseq2.data import CollateOptionsOverride, Collater
from fairseq2.data import CollateOptionsOverride, Collater, read_sequence
from tests.common import assert_close, assert_equal, device


Expand Down Expand Up @@ -378,3 +378,50 @@ def test_init_raises_error_when_pad_idx_is_none_and_pad_to_multiple_is_greater_t
match=r"^`pad_idx` of the selector 'foo' must be set when `pad_to_multiple` is greater than 1\.$",
):
Collater(overrides=[CollateOptionsOverride("foo", pad_to_multiple=2)])


@pytest.mark.parametrize("pad_to_multiple,pad_size", [(1, 0), (2, 0), (3, 2), (8, 4)])
def test_collate_works_when_input_has_sequence_tensors(
    pad_to_multiple: int, pad_size: int
) -> None:
    """Run ``.collate()`` over two buckets and check both collated outputs.

    The first bucket holds three (4, 2) integer tensors, the second holds
    nested lists of dicts. ``pad_size`` is the number of padding rows expected
    on the sequence dimension for the given ``pad_to_multiple``.
    """
    # Three same-length sequences filled with 0, 1 and 2 respectively.
    tensor_bucket = [
        torch.full((4, 2), fill, device=device, dtype=torch.int64)
        for fill in range(3)
    ]

    dict_bucket = [
        [{"foo1": 0, "foo2": 1}, {"foo3": 2, "foo4": 3}],
        [{"foo1": 4, "foo2": 5}, {"foo3": 6, "foo4": 7}],
        [{"foo1": 8, "foo2": 9}, {"foo3": 0, "foo4": 1}],
    ]

    # Expected batch for the tensor bucket: the stacked inputs, with
    # `pad_size` extra rows of the pad value (3) on the sequence dimension.
    stacked = torch.tensor(
        [[[fill, fill]] * 4 for fill in range(3)],
        device=device,
        dtype=torch.int64,
    )
    want_seqs = pad(stacked, (0, 0, 0, pad_size), value=3)
    want_seq_lens = torch.full((3,), 4, device=device, dtype=torch.int64)

    # Expected result for the dict bucket: values gathered per key.
    want_dicts = [
        {"foo1": [0, 4, 8], "foo2": [1, 5, 9]},
        {"foo3": [2, 6, 0], "foo4": [3, 7, 1]},
    ]

    builder = read_sequence([tensor_bucket, dict_bucket])
    builder = builder.collate(pad_idx=3, pad_to_multiple=pad_to_multiple)
    pipeline = builder.and_return()

    got_tensors, got_dicts = list(pipeline)

    assert_close(got_tensors["seqs"], want_seqs)
    assert_equal(got_tensors["seq_lens"], want_seq_lens)
    assert got_tensors["is_ragged"] == False

    assert got_dicts == want_dicts

0 comments on commit 4b71687

Please sign in to comment.