Skip to content

Commit

Permalink
Merge branch 'facebookresearch:main' into img_processing
Browse files Browse the repository at this point in the history
  • Loading branch information
am831 authored Oct 18, 2023
2 parents 8a48b9b + e969622 commit 8d12ed8
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 3 deletions.
18 changes: 17 additions & 1 deletion fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,24 @@ def_data_pipeline(py::module_ &data_module)
return data_pipeline::count(start, std::move(key));
},
py::arg("start") = 0,
py::arg("key") = std::nullopt);
py::arg("key") = std::nullopt)
        .def_static(
            "concat",
            // Accepts the Python-side pipelines by reference so they can be
            // moved out of rather than copied (`data_pipeline` is move-only).
            [](std::vector<std::reference_wrapper<data_pipeline>> &refs)
            {
                std::vector<data_pipeline> pipelines{};

                pipelines.reserve(refs.size());

                // Move each pipeline out of its Python-owned wrapper; the
                // originals are left in a moved-from state afterwards.
                std::transform(
                    refs.begin(), refs.end(), std::back_inserter(pipelines), [](auto &r) {
                        return std::move(r.get());
                    });

                return data_pipeline::concat(std::move(pipelines));
            },
            py::arg("pipelines"));

// DataPipelineIterator
py::class_<data_pipeline_iterator>(m, "_DataPipelineIterator")
.def(
Expand Down
1 change: 1 addition & 0 deletions fairseq2n/src/fairseq2n/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ target_sources(fairseq2n
data/bucket_data_source.cc
data/byte_stream.cc
data/collater.cc
data/concat_data_source.cc
data/constant_data_source.cc
data/count_data_source.cc
data/data.cc
Expand Down
47 changes: 47 additions & 0 deletions fairseq2n/src/fairseq2n/data/concat_data_source.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "fairseq2n/data/concat_data_source.h"
#include <vector>

namespace fairseq2n::detail {

concat_data_source::concat_data_source(std::vector<data_pipeline> &&pipelines)
: pipelines_(std::move(pipelines))
{}

std::optional<data>
concat_data_source::next()
{
std::optional<data> d;
for (auto &p : pipelines_) {
d = p.next();
if (d)
return d;
}
return {};
}

void concat_data_source::reset()
{
for (auto &pipeline : pipelines_)
pipeline.reset();
}

void concat_data_source::record_position(tape &t) const
{
    // Checkpoint each inner pipeline onto the tape, in pipeline order;
    // `reload_position()` reads them back in the same order.
    for (auto pos = pipelines_.begin(); pos != pipelines_.end(); ++pos)
        pos->record_position(t);
}

void concat_data_source::reload_position(tape &t)
{
    // Restore each inner pipeline from the tape, in the same order that
    // `record_position()` wrote them.
    for (auto pos = pipelines_.begin(); pos != pipelines_.end(); ++pos)
        pos->reload_position(t);
}

} // namespace fairseq2n::detail

40 changes: 40 additions & 0 deletions fairseq2n/src/fairseq2n/data/concat_data_source.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <memory>
#include <utility>
#include <vector>

#include "fairseq2n/data/data_pipeline.h"
#include "fairseq2n/data/data_source.h"

namespace fairseq2n::detail {

// Yields the examples of a sequence of pipelines back to back: all examples
// of the first pipeline, then all of the second, and so on.
class concat_data_source final : public data_source {
public:
    // Takes ownership of `pipelines`.
    explicit
    concat_data_source(
        std::vector<data_pipeline> &&pipelines);

    // Returns the next example from the first non-exhausted pipeline, or
    // `std::nullopt` once every pipeline is exhausted.
    std::optional<data>
    next() override;

    // Resets all inner pipelines.
    void
    reset() override;

    // Records the position of every inner pipeline onto `t`, in order.
    void
    record_position(tape &t) const override;

    // Restores the position of every inner pipeline from `t`, in order.
    void
    reload_position(tape &t) override;

private:
    std::vector<data_pipeline> pipelines_;
};

} // namespace fairseq2n::detail
34 changes: 32 additions & 2 deletions fairseq2n/src/fairseq2n/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,27 @@
#include <system_error>
#include <utility>

#include "data_pipeline.h"
#include "fairseq2n/data/bucket_by_length_data_source.h"
#include "fairseq2n/data/bucket_data_source.h"
#include "fairseq2n/data/concat_data_source.h"
#include "fairseq2n/data/constant_data_source.h"
#include "fairseq2n/data/count_data_source.h"
#include "fairseq2n/data/detail/file_system.h"
#include "fairseq2n/data/filter_data_source.h"
#include "fairseq2n/data/list_data_source.h"
#include "fairseq2n/data/map_data_source.h"
#include "fairseq2n/data/prefetch_data_source.h"
#include "fairseq2n/data/take_data_source.h"
#include "fairseq2n/data/round_robin_data_source.h"
#include "fairseq2n/data/sample_data_source.h"
#include "fairseq2n/data/shard_data_source.h"
#include "fairseq2n/data/shuffle_data_source.h"
#include "fairseq2n/data/skip_data_source.h"
#include "fairseq2n/data/take_data_source.h"
#include "fairseq2n/data/tape.h"
#include "fairseq2n/data/yield_from_data_source.h"
#include "fairseq2n/data/zip_data_source.h"
#include "fairseq2n/data/zip_file_data_source.h"
#include "fairseq2n/data/detail/file_system.h"
#include "fairseq2n/detail/exception.h"

using namespace fairseq2n::detail;
Expand Down Expand Up @@ -283,6 +285,34 @@ data_pipeline::count(std::int64_t start, std::optional<std::string> key)
return data_pipeline_builder{std::move(factory)};
}

data_pipeline_builder
data_pipeline::concat(
    std::vector<data_pipeline> pipelines)
{
    // Reject an empty set up front; a concatenation of nothing is a caller bug.
    if (pipelines.empty())
        throw_<std::invalid_argument>(
            "`pipelines` does not contain any elements. Can not concatenate from empty set.");

    // A broken pipeline can never yield examples again, so fail eagerly.
    for (const data_pipeline &pipeline : pipelines) {
        if (pipeline.is_broken())
            throw_<std::invalid_argument>(
                "At least one of the specified data pipelines is broken and cannot be concatenated.");
    }

    // NOTE(review): `data_pipeline` is move-only, but the factory closure
    // presumably has to be copyable (it is stored as a `data_source_factory`);
    // hence the pipelines are stashed behind a `shared_ptr` and moved out of
    // it on the (single) factory invocation — confirm against the factory type.
    auto shared_pipelines = std::make_shared<std::vector<data_pipeline>>(std::move(pipelines));

    auto factory = [shared_pipelines]() mutable
    {
        return std::make_unique<concat_data_source>(std::move(*shared_pipelines));
    };

    return data_pipeline_builder{std::move(factory)};
}

data_pipeline_builder
data_pipeline_builder::bucket(std::size_t bucket_size, bool drop_remainder) &&
{
Expand Down
3 changes: 3 additions & 0 deletions fairseq2n/src/fairseq2n/data/data_pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ class FAIRSEQ2_API data_pipeline {
static data_pipeline_builder
count(std::int64_t start = 0, std::optional<std::string> key = {});

static data_pipeline_builder
concat(std::vector<data_pipeline> pipelines);

private:
data_source_factory factory_{};
std::unique_ptr<data_source> source_{};
Expand Down
9 changes: 9 additions & 0 deletions src/fairseq2/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,15 @@ def constant(example: Any, key: Optional[str] = None) -> "DataPipelineBuilder":
def count(start: int = 0, key: Optional[str] = None) -> "DataPipelineBuilder":
...

@staticmethod
def concat(pipelines: Sequence["DataPipeline"]) -> "DataPipelineBuilder":
    """Concatenate examples from ``pipelines``.

    Examples are yielded back to back: all examples of the first pipeline,
    then all examples of the second, and so on.

    :param pipelines:
        The data pipelines to concatenate.
    """
    ...

class DataPipelineBuilder:
"""API to create DataPipeline"""

Expand Down
106 changes: 106 additions & 0 deletions tests/unit/data/data_pipeline/test_concat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

from fairseq2.data import DataPipeline, DataPipelineError, read_sequence
from fairseq2.data.text import read_text
from tests.common import python_devel_only


@pytest.mark.skipif(
    python_devel_only(),
    reason="New fairseq2n API in Python-only installation. Skipping till v0.2.",
)
class TestConcatOp:
    def test_op_works(self) -> None:
        pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
        pipeline2 = read_sequence([5, 6, 7, 8]).and_return()

        pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()

        # Iterate twice to verify that `reset()` restores the full sequence.
        for _ in range(2):
            assert list(pipeline) == [1, 2, 3, 4, 5, 6, 7, 8]

            pipeline.reset()

    def test_op_works_when_pipelines_are_empty(self) -> None:
        pipeline1 = read_sequence([]).and_return()
        pipeline2 = read_sequence([]).and_return()

        pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()

        for _ in range(2):
            assert list(pipeline) == []

            pipeline.reset()

    def test_op_works_when_pipelines_have_different_lengths(self) -> None:
        pipeline1 = read_sequence([1, 2, 3]).and_return()
        pipeline2 = read_sequence([4, 5]).and_return()

        pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()

        for _ in range(2):
            assert list(pipeline) == [1, 2, 3, 4, 5]

            pipeline.reset()

    def test_op_raises_error_when_one_of_the_pipelines_is_broken(self) -> None:
        # Force a non-recoverable error.
        pipeline1 = read_text(pathname=" &^#").and_return()
        pipeline2 = read_text(pathname=" &^#").and_return()

        # Break the first pipeline.
        try:
            next(iter(pipeline1))
        except DataPipelineError:
            assert pipeline1.is_broken

        with pytest.raises(
            ValueError,
            match=r"^At least one of the specified data pipelines is broken and cannot be concatenated\.$",
        ):
            DataPipeline.concat([pipeline1, pipeline2]).and_return()

    def test_op_saves_and_restores_its_state(self) -> None:
        pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
        pipeline2 = read_sequence([5, 6, 7, 8]).and_return()

        pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()

        d = None

        it = iter(pipeline)

        # Move to the second example.
        for _ in range(2):
            d = next(it)

        assert d == 2

        state_dict = pipeline.state_dict()

        # Read one more example before we roll back.
        d = next(it)

        assert d == 3

        # Expected to roll back to the second example.
        pipeline.load_state_dict(state_dict)

        # Move to EOD.
        for _ in range(6):
            d = next(it)

        assert d == 8

        state_dict = pipeline.state_dict()

        pipeline.reset()

        # Expected to be EOD.
        pipeline.load_state_dict(state_dict)

        with pytest.raises(StopIteration):
            next(iter(pipeline))

0 comments on commit 8d12ed8

Please sign in to comment.