-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'facebookresearch:main' into img_processing
- Loading branch information
Showing
8 changed files
with
255 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// Copyright (c) Meta Platforms, Inc. and affiliates. | ||
// All rights reserved. | ||
// | ||
// This source code is licensed under the BSD-style license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#include "fairseq2n/data/concat_data_source.h" | ||
#include <vector> | ||
|
||
namespace fairseq2n::detail { | ||
|
||
concat_data_source::concat_data_source(std::vector<data_pipeline> &&pipelines) | ||
: pipelines_(std::move(pipelines)) | ||
{} | ||
|
||
std::optional<data> | ||
concat_data_source::next() | ||
{ | ||
std::optional<data> d; | ||
for (auto &p : pipelines_) { | ||
d = p.next(); | ||
if (d) | ||
return d; | ||
} | ||
return {}; | ||
} | ||
|
||
void concat_data_source::reset() | ||
{ | ||
for (auto &pipeline : pipelines_) | ||
pipeline.reset(); | ||
} | ||
|
||
void concat_data_source::record_position(tape &t) const | ||
{ | ||
for (auto &pipeline : pipelines_) | ||
pipeline.record_position(t); | ||
} | ||
|
||
void concat_data_source::reload_position(tape &t) | ||
{ | ||
for (auto &pipeline : pipelines_) | ||
pipeline.reload_position(t); | ||
} | ||
|
||
} // namespace fairseq2n::detail | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// Copyright (c) Meta Platforms, Inc. and affiliates. | ||
// All rights reserved. | ||
// | ||
// This source code is licensed under the BSD-style license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#pragma once | ||
|
||
#include <memory> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "fairseq2n/data/data_pipeline.h" | ||
#include "fairseq2n/data/data_source.h" | ||
|
||
namespace fairseq2n::detail { | ||
|
||
class concat_data_source final : public data_source { | ||
public: | ||
explicit | ||
concat_data_source( | ||
std::vector<data_pipeline> &&pipelines); | ||
|
||
std::optional<data> | ||
next() override; | ||
|
||
void | ||
reset() override; | ||
|
||
void | ||
record_position(tape &t) const override; | ||
|
||
void | ||
reload_position(tape &t) override; | ||
|
||
private: | ||
std::vector<data_pipeline> pipelines_; | ||
}; | ||
|
||
} // namespace fairseq2n::detail |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import pytest | ||
|
||
from fairseq2.data import DataPipeline, DataPipelineError, read_sequence | ||
from fairseq2.data.text import read_text | ||
from tests.common import python_devel_only | ||
|
||
|
||
@pytest.mark.skipif( | ||
python_devel_only(), | ||
reason="New fairseq2n API in Python-only installation. Skipping till v0.2.", | ||
) | ||
class TestConcatOp: | ||
def test_op_works(self) -> None: | ||
pipeline1 = read_sequence([1, 2, 3, 4]).and_return() | ||
pipeline2 = read_sequence([5, 6, 7, 8]).and_return() | ||
|
||
pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return() | ||
|
||
for _ in range(2): | ||
assert list(pipeline) == [1, 2, 3, 4, 5, 6, 7, 8] | ||
pipeline.reset() | ||
|
||
def test_op_works_when_pipelines_are_empty(self) -> None: | ||
pipeline1 = read_sequence([]).and_return() | ||
pipeline2 = read_sequence([]).and_return() | ||
|
||
pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return() | ||
|
||
for _ in range(2): | ||
assert list(pipeline) == [] | ||
pipeline.reset() | ||
|
||
def test_op_works_when_pipelines_have_different_lengths(self) -> None: | ||
pipeline1 = read_sequence([1, 2, 3]).and_return() | ||
pipeline2 = read_sequence([4, 5]).and_return() | ||
|
||
pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return() | ||
|
||
for _ in range(2): | ||
assert list(pipeline) == [1, 2, 3, 4, 5] | ||
pipeline.reset() | ||
|
||
def test_op_raises_error_when_one_of_the_pipelines_is_broken(self) -> None: | ||
# Force a non-recoverable error. | ||
pipeline1 = read_text(pathname=" &^#").and_return() | ||
pipeline2 = read_text(pathname=" &^#").and_return() | ||
|
||
# Break the first pipeline. | ||
try: | ||
next(iter(pipeline1)) | ||
except DataPipelineError: | ||
assert pipeline1.is_broken | ||
|
||
with pytest.raises( | ||
ValueError, | ||
match=r"^At least one of the specified data pipelines is broken and cannot be concatenated\.$", | ||
): | ||
DataPipeline.concat([pipeline1, pipeline2]).and_return() | ||
|
||
def test_op_saves_and_restores_its_state(self) -> None: | ||
pipeline1 = read_sequence([1, 2, 3, 4]).and_return() | ||
pipeline2 = read_sequence([5, 6, 7, 8]).and_return() | ||
|
||
pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return() | ||
|
||
d = None | ||
|
||
it = iter(pipeline) | ||
|
||
# Move the the second example. | ||
for _ in range(2): | ||
d = next(it) | ||
|
||
assert d == 2 | ||
|
||
state_dict = pipeline.state_dict() | ||
|
||
# Read one more example before we roll back. | ||
d = next(it) | ||
|
||
assert d == 3 | ||
|
||
# Expected to roll back to the second example. | ||
pipeline.load_state_dict(state_dict) | ||
|
||
# Move to EOD. | ||
for _ in range(6): | ||
d = next(it) | ||
|
||
assert d == 8 | ||
|
||
state_dict = pipeline.state_dict() | ||
|
||
pipeline.reset() | ||
|
||
# Expected to be EOD. | ||
pipeline.load_state_dict(state_dict) | ||
|
||
with pytest.raises(StopIteration): | ||
next(iter(pipeline)) |