Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Concatenate method for DataPipeline class #84

Merged
merged 33 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
c925bce
concat function
am831 Oct 3, 2023
3d3370e
expose cat to python
am831 Oct 3, 2023
19e98da
add method declaration, add data_source file
am831 Oct 4, 2023
2bf2ce0
data_source file
am831 Oct 4, 2023
d846421
fix cat method in pybinding
am831 Oct 4, 2023
8ea7c72
Merge branch 'facebookresearch:main' into concat
am831 Oct 4, 2023
6caf742
Merge branch 'concat' of https://github.com/am831/fairseq2 into concat
am831 Oct 4, 2023
8a2f2b5
implement next, reload and concatenate
am831 Oct 4, 2023
65aa1a1
Merge branch 'main' into concat
am831 Oct 5, 2023
025c4ff
fix parameters for cat
am831 Oct 6, 2023
c260b8d
add cat_data_source to CMakeList
am831 Oct 6, 2023
e3a20bc
fix lint issues
am831 Oct 6, 2023
05b0aaf
fix type mismatch
am831 Oct 6, 2023
200fc56
fixed build errors - still some logic errors
am831 Oct 6, 2023
e6616ea
fixed build errors
am831 Oct 6, 2023
08ed63e
Merge branch 'concat' of https://github.com/am831/fairseq2 into concat
am831 Oct 6, 2023
6f5f485
Merge branch 'facebookresearch:main' into concat
am831 Oct 6, 2023
5212130
unit tests
am831 Oct 7, 2023
62ffb53
Merge branch 'facebookresearch:main' into concat
am831 Oct 9, 2023
a5d0185
Merge branch 'main' into concat
am831 Oct 9, 2023
4609f38
Merge branch 'facebookresearch:main' into concat
am831 Oct 10, 2023
cac280f
fix concat parameter and implementation
am831 Oct 11, 2023
8b3f08c
Merge branch 'main' into concat
am831 Oct 11, 2023
d49b1fd
remove renamed files
am831 Oct 11, 2023
7df592b
fix unit tests, remove concatenate func
am831 Oct 11, 2023
c1ce524
format
am831 Oct 11, 2023
e7455a0
Merge branch 'facebookresearch:main' into concat
am831 Oct 16, 2023
a2173b2
Merge branch 'facebookresearch:main' into concat
am831 Oct 17, 2023
03ac1ab
fix bug and syntax issues
am831 Oct 18, 2023
3b91ceb
Merge branch 'concat' of https://github.com/am831/fairseq2 into concat
am831 Oct 18, 2023
18e472b
fix regex pattern
am831 Oct 18, 2023
f0ecd29
format with black
am831 Oct 18, 2023
3974802
Merge branch 'facebookresearch:main' into concat
am831 Oct 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,24 @@ def_data_pipeline(py::module_ &data_module)
return data_pipeline::count(start, std::move(key));
},
py::arg("start") = 0,
py::arg("key") = std::nullopt);
py::arg("key") = std::nullopt)
.def_static(
    "concat",
    [](std::vector<std::reference_wrapper<data_pipeline>> &refs)
    {
        // `data_pipeline` is move-only; take each pipeline out of its
        // reference wrapper. The Python-side objects are left in a
        // moved-from state afterwards.
        std::vector<data_pipeline> pipelines{};

        pipelines.reserve(refs.size());

        for (auto &ref : refs)
            pipelines.push_back(std::move(ref.get()));

        return data_pipeline::concat(std::move(pipelines));
    },
    py::arg("pipelines"));

// DataPipelineIterator
py::class_<data_pipeline_iterator>(m, "_DataPipelineIterator")
.def(
Expand Down
1 change: 1 addition & 0 deletions fairseq2n/src/fairseq2n/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ target_sources(fairseq2n
data/bucket_data_source.cc
data/byte_stream.cc
data/collater.cc
data/concat_data_source.cc
data/constant_data_source.cc
data/count_data_source.cc
data/data.cc
Expand Down
49 changes: 49 additions & 0 deletions fairseq2n/src/fairseq2n/data/concat_data_source.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "fairseq2n/data/concat_data_source.h"
#include <vector>

namespace fairseq2n::detail {

concat_data_source::concat_data_source(
am831 marked this conversation as resolved.
Show resolved Hide resolved
std::vector<data_pipeline> &&pipelines)
: pipelines_{std::move(pipelines)}
am831 marked this conversation as resolved.
Show resolved Hide resolved
{}

std::optional<data>
concat_data_source::next()
{
std::optional<data> d;
for (auto &p : pipelines_) {
d = p.next();
if (d) {
am831 marked this conversation as resolved.
Show resolved Hide resolved
return d;
}
}
return {};
}

void concat_data_source::reset()
{
for (auto &pipeline : pipelines_)
pipeline.reset();
}

void concat_data_source::record_position(tape &t) const
{
    // Persist each inner pipeline's position onto the tape, in order.
    for (const data_pipeline &p : pipelines_)
        p.record_position(t);
}

void concat_data_source::reload_position(tape &t)
{
    // Restore each inner pipeline's position in the same order it was
    // recorded by record_position().
    for (data_pipeline &p : pipelines_)
        p.reload_position(t);
}

} // namespace fairseq2n::detail

40 changes: 40 additions & 0 deletions fairseq2n/src/fairseq2n/data/concat_data_source.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <memory>
#include <utility>
#include <vector>

#include "fairseq2n/data/data_pipeline.h"
#include "fairseq2n/data/data_source.h"

namespace fairseq2n::detail {

/// A data source that yields the examples of several data pipelines
/// back-to-back: every example of the first pipeline, then every example of
/// the second, and so on.
class concat_data_source final : public data_source {
public:
    explicit
    concat_data_source(
        std::vector<data_pipeline> &&pipelines);

    /// Returns the next example from the first pipeline that is not yet
    /// exhausted, or nullopt once all pipelines are exhausted.
    std::optional<data>
    next() override;

    /// Rewinds every inner pipeline to its initial state.
    void
    reset() override;

    /// Records the position of each inner pipeline, in order.
    void
    record_position(tape &t) const override;

    /// Restores the position of each inner pipeline, in the order recorded.
    void
    reload_position(tape &t) override;

private:
    std::vector<data_pipeline> pipelines_;
};

} // namespace fairseq2n::detail
34 changes: 32 additions & 2 deletions fairseq2n/src/fairseq2n/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,27 @@
#include <system_error>
#include <utility>

#include "data_pipeline.h"
#include "fairseq2n/data/bucket_by_length_data_source.h"
#include "fairseq2n/data/bucket_data_source.h"
#include "fairseq2n/data/concat_data_source.h"
#include "fairseq2n/data/constant_data_source.h"
#include "fairseq2n/data/count_data_source.h"
#include "fairseq2n/data/detail/file_system.h"
#include "fairseq2n/data/filter_data_source.h"
#include "fairseq2n/data/list_data_source.h"
#include "fairseq2n/data/map_data_source.h"
#include "fairseq2n/data/prefetch_data_source.h"
#include "fairseq2n/data/take_data_source.h"
#include "fairseq2n/data/round_robin_data_source.h"
#include "fairseq2n/data/sample_data_source.h"
#include "fairseq2n/data/shard_data_source.h"
#include "fairseq2n/data/shuffle_data_source.h"
#include "fairseq2n/data/skip_data_source.h"
#include "fairseq2n/data/take_data_source.h"
#include "fairseq2n/data/tape.h"
#include "fairseq2n/data/yield_from_data_source.h"
#include "fairseq2n/data/zip_data_source.h"
#include "fairseq2n/data/zip_file_data_source.h"
#include "fairseq2n/data/detail/file_system.h"
#include "fairseq2n/detail/exception.h"

using namespace fairseq2n::detail;
Expand Down Expand Up @@ -283,6 +285,34 @@ data_pipeline::count(std::int64_t start, std::optional<std::string> key)
return data_pipeline_builder{std::move(factory)};
}

data_pipeline_builder
data_pipeline::concat(std::vector<data_pipeline> pipelines)
{
    // Reject an empty input up front; there is nothing to concatenate.
    if (pipelines.empty())
        throw_<std::invalid_argument>(
            "`pipelines` does not contain any elements. Can not concatenate from empty set.");

    // A pipeline that previously failed with a non-recoverable error must not
    // be reused.
    for (const data_pipeline &pipeline : pipelines) {
        if (pipeline.is_broken())
            throw_<std::invalid_argument>(
                "At least one of the specified data pipelines is broken and cannot be used in concat.");
    }

    // The pipelines are parked in a shared_ptr so the closure stays copyable
    // even though data_pipeline is move-only. NOTE(review): the factory moves
    // the pipelines out on its first call, so it is effectively single-use —
    // confirm callers never invoke a factory twice.
    auto shared = std::make_shared<std::vector<data_pipeline>>(std::move(pipelines));

    auto factory = [shared]() mutable
    {
        return std::make_unique<concat_data_source>(std::move(*shared));
    };

    return data_pipeline_builder{std::move(factory)};
}

data_pipeline_builder
data_pipeline_builder::bucket(std::size_t bucket_size, bool drop_remainder) &&
{
Expand Down
3 changes: 3 additions & 0 deletions fairseq2n/src/fairseq2n/data/data_pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ class FAIRSEQ2_API data_pipeline {
static data_pipeline_builder
count(std::int64_t start = 0, std::optional<std::string> key = {});

static data_pipeline_builder
concat(std::vector<data_pipeline> pipelines);

private:
data_source_factory factory_{};
std::unique_ptr<data_source> source_{};
Expand Down
9 changes: 9 additions & 0 deletions src/fairseq2/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,15 @@ def constant(example: Any, key: Optional[str] = None) -> "DataPipelineBuilder":
def count(start: int = 0, key: Optional[str] = None) -> "DataPipelineBuilder":
...

@staticmethod
def concat(pipelines: Sequence["DataPipeline"]) -> "DataPipelineBuilder":
    """Concatenate examples from ``pipelines``.

    Examples are yielded in order: all examples of the first pipeline,
    then all examples of the second, and so on.

    :param pipelines:
        The data pipelines to concatenate.
    """
    ...

class DataPipelineBuilder:
"""API to create DataPipeline"""

Expand Down
101 changes: 101 additions & 0 deletions tests/unit/data/data_pipeline/test_conat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

from fairseq2.data import DataPipeline, DataPipelineError, read_sequence
from fairseq2.data.text import read_text


class TestConcatOp:
    def test_op_works(self) -> None:
        pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
        pipeline2 = read_sequence([5, 6, 7, 8]).and_return()

        pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()

        # Two passes confirm that reset() rewinds the whole concatenation.
        for _ in range(2):
            assert list(pipeline) == [1, 2, 3, 4, 5, 6, 7, 8]

            pipeline.reset()

    def test_op_works_when_pipelines_are_empty(self) -> None:
        pipeline1 = read_sequence([]).and_return()
        pipeline2 = read_sequence([]).and_return()

        pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()

        for _ in range(2):
            assert list(pipeline) == []

            pipeline.reset()

    def test_op_works_when_pipelines_have_different_lengths(self) -> None:
        pipeline1 = read_sequence([1, 2, 3]).and_return()
        pipeline2 = read_sequence([4, 5]).and_return()

        pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()

        for _ in range(2):
            assert list(pipeline) == [1, 2, 3, 4, 5]

            pipeline.reset()

    def test_op_raises_error_when_one_of_the_pipelines_is_broken(self) -> None:
        # Force a non-recoverable error.
        pipeline1 = read_text(pathname=" &^#").and_return()
        pipeline2 = read_text(pathname=" &^#").and_return()

        # Break the first pipeline.
        try:
            next(iter(pipeline1))
        except DataPipelineError:
            assert pipeline1.is_broken

        with pytest.raises(
            ValueError,
            match=r"^At least one of the specified data pipelines is broken and cannot be used in concat\.$",
        ):
            DataPipeline.concat([pipeline1, pipeline2]).and_return()

    def test_op_saves_and_restores_its_state(self) -> None:
        pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
        pipeline2 = read_sequence([5, 6, 7, 8]).and_return()

        pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()

        d = None

        it = iter(pipeline)

        # Move to the second example.
        for _ in range(2):
            d = next(it)

        assert d == 2

        state_dict = pipeline.state_dict()

        # Read one more example before we roll back.
        d = next(it)

        assert d == 3

        # Expected to roll back to the second example.
        pipeline.load_state_dict(state_dict)

        # Move to EOD.
        for _ in range(6):
            d = next(it)

        assert d == 8

        state_dict = pipeline.state_dict()

        pipeline.reset()

        # Expected to be EOD.
        pipeline.load_state_dict(state_dict)

        with pytest.raises(StopIteration):
            next(iter(pipeline))