Skip to content

Commit

Permalink
Introduce stop_at_shortest in sample and round_robin (#76)
Browse files Browse the repository at this point in the history
* Add tensor.h to utils

* Add sample_data_source class

* Add sample method to data_pipeline

* Add sample binding to data_pipeline

* Add test_sample

* Fix build errors

* fix python lint errors

* Add empty lines at end of c++ files

* remove \n from runtime error regex match

* Use float32 instead of float

* Resolve nit comments

* Remove set_seed from sampler

* Fix document error

* Handle generator state out of sample operator

* Add circular data source

* Use circular data_source in round_robin

* fix circular data source

* Add stop_on_shortest flag

* Add stop_at_shortest flag to round_robin

* Add stop_at_shortest round_robin test

* Fix circular data source stop_at_shortest flag

* Use circular data_source in sample op

* Add up_sampling test

* Rename circular_data_source to multi_data_source

* Improve sample and round_robin doc

* lint code

* skip test that uses new API

* nit changes

* rename multi_data_source to composite_data_source

* use std::exchange instead of =

* resolve last nit comment
  • Loading branch information
najielhachem authored Oct 5, 2023
1 parent e081284 commit 52b05c7
Show file tree
Hide file tree
Showing 13 changed files with 302 additions and 146 deletions.
17 changes: 11 additions & 6 deletions fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,9 @@ def_data_pipeline(py::module_ &data_module)
py::arg("disable_parallelism") = false)
.def_static(
"round_robin",
[](std::vector<std::reference_wrapper<data_pipeline>> &refs)
[](
std::vector<std::reference_wrapper<data_pipeline>> &refs,
bool stop_at_shortest)
{
std::vector<data_pipeline> pipelines{};

Expand All @@ -315,14 +317,16 @@ def_data_pipeline(py::module_ &data_module)
return std::move(r.get());
});

return data_pipeline::round_robin(std::move(pipelines));
return data_pipeline::round_robin(std::move(pipelines), stop_at_shortest);
},
py::arg("pipelines"))
py::arg("pipelines"),
py::arg("stop_at_shortest") = false)
.def_static(
"sample",
[](
std::vector<std::reference_wrapper<data_pipeline>> &refs,
std::optional<std::vector<float>> weights)
std::optional<std::vector<float>> weights,
bool stop_at_shortest)
{
std::vector<data_pipeline> pipelines{};

Expand All @@ -334,10 +338,11 @@ def_data_pipeline(py::module_ &data_module)
});

return data_pipeline::sample(
std::move(pipelines), std::move(weights));
std::move(pipelines), std::move(weights), stop_at_shortest);
},
py::arg("pipelines"),
py::arg("weights") = std::nullopt)
py::arg("weights") = std::nullopt,
py::arg("stop_at_shortest") = false)
.def_static(
"constant",
[](data example, std::optional<std::string> key)
Expand Down
1 change: 1 addition & 0 deletions fairseq2n/src/fairseq2n/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ target_sources(fairseq2n
data/list_data_source.cc
data/map_data_source.cc
data/memory_stream.cc
data/composite_data_source.cc
data/prefetch_data_source.cc
data/py.cc
data/record_reader.cc
Expand Down
114 changes: 114 additions & 0 deletions fairseq2n/src/fairseq2n/data/composite_data_source.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "fairseq2n/data/composite_data_source.h"

namespace fairseq2n::detail {

// Builds a composite source over `pipelines`.
//
// `index_gen_fn` is stored and later invoked by `next()` to pick which
// pipeline is read. `stop_at_shortest` selects the iteration mode; only when
// it is false are the per-pipeline look-ahead buffer and the "epoch done"
// flags sized (one slot per pipeline, initially empty / false).
composite_data_source::composite_data_source(
    std::vector<data_pipeline> &&pipelines,
    index_generator_fn &&index_gen_fn,
    bool stop_at_shortest)
  : pipelines_(std::move(pipelines)),
    next_index_gen_{std::move(index_gen_fn)},
    stop_at_shortest_{stop_at_shortest}
{
    // The bookkeeping below is only consulted when pipelines wrap around.
    if (stop_at_shortest_)
        return;

    const std::size_t pipeline_count = pipelines_.size();

    is_epoch_done_.assign(pipeline_count, false);
    buffer_.assign(pipeline_count, std::nullopt);
}

std::optional<data>
composite_data_source::next()
{
if (stop_at_shortest_) // with this flag on, the operator is a simple iterator
return pipelines_[next_index_gen_()].next();

// One or more data pipelines might be empty, so we have to keep looping
std::optional<data> output{};
while (!output && !eod()) {
auto pipeline_idx = next_index_gen_();
auto &maybe_example = buffer_[pipeline_idx];

if (!maybe_example) // init buffer at first call
maybe_example = next_in_pipeline(pipeline_idx);

output = std::exchange(maybe_example, next_in_pipeline(pipeline_idx));
}

return output;
}

// Rewinds every underlying pipeline. When pipelines are allowed to wrap
// around (`stop_at_shortest_` is false), the look-ahead buffer, the per-
// pipeline "epoch done" flags and the cached end-of-data flag are cleared too.
void
composite_data_source::reset()
{
    for (auto &source : pipelines_)
        source.reset();

    if (stop_at_shortest_)
        return;

    const std::size_t pipeline_count = pipelines_.size();

    buffer_.assign(pipeline_count, std::nullopt);
    is_epoch_done_.assign(pipeline_count, false);

    is_eod_ = false;
}

// Writes the current position to tape `t`: first every pipeline, in order,
// then - only when `stop_at_shortest_` is false - the look-ahead buffer
// followed by the "epoch done" flags. `reload_position()` reads the values
// back in the same order.
void
composite_data_source::record_position(tape &t) const
{
    for (const auto &source : pipelines_)
        source.record_position(t);

    if (stop_at_shortest_)
        return;

    t.record(buffer_);
    t.record(is_epoch_done_);
}

// Restores a position from tape `t`, mirroring `record_position()`: every
// pipeline first, then - only when `stop_at_shortest_` is false - the
// look-ahead buffer and the "epoch done" flags.
void
composite_data_source::reload_position(tape &t)
{
    for (auto &source : pipelines_)
        source.reload_position(t);

    if (stop_at_shortest_)
        return;

    buffer_ = t.read<std::vector<std::optional<data>>>();
    is_epoch_done_ = t.read<std::vector<bool>>();

    // The cached flag is not stored on the tape; `eod()` recomputes it from
    // the restored per-pipeline flags.
    is_eod_ = false;
}

std::optional<data>
composite_data_source::next_in_pipeline(std::size_t pipeline_idx)
{
data_pipeline &pipeline = pipelines_[pipeline_idx];

std::optional<data> maybe_example = pipeline.next();
if (!maybe_example) {
is_epoch_done_[pipeline_idx] = true;

pipeline.reset();

// Circle back to the first example.
maybe_example = pipeline.next();
}

return maybe_example;
}

// Tells whether every pipeline has been flagged as having completed its
// epoch. A positive answer is cached in `is_eod_`, so later calls skip the
// scan until the flag is cleared again.
bool
composite_data_source::eod()
{
    if (is_eod_)
        return true;

    const auto is_done = [](bool flag)
    {
        return flag;
    };

    is_eod_ = std::all_of(is_epoch_done_.begin(), is_epoch_done_.end(), is_done);

    return is_eod_;
}

} // namespace fairseq2n::detail
51 changes: 51 additions & 0 deletions fairseq2n/src/fairseq2n/data/composite_data_source.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <vector>

#include "fairseq2n/data/data_pipeline.h"


// Callable taking no arguments and returning a `std::size_t`.
// `composite_data_source` stores one and uses each returned value as an index
// into its list of pipelines.
// NOTE(review): this alias is declared at global scope, outside of
// `fairseq2n::detail` - confirm whether that is intended.
using index_generator_fn = std::function<std::size_t()>;

namespace fairseq2n::detail {

// A data source that yields examples from several pipelines, using a
// caller-provided index generator to decide which pipeline is read on each
// call to `next()`.
//
// If `stop_at_shortest` is true, `next()` simply forwards to the selected
// pipeline. Otherwise pipelines that run out are reset and read again from
// their start, and iteration ends once every pipeline has completed at least
// one epoch.
class composite_data_source final : public data_source {
public:
    explicit
    composite_data_source(
        std::vector<data_pipeline> &&pipelines,
        index_generator_fn &&index_gen_fn,
        bool stop_at_shortest);

    // Returns the next example, or `std::nullopt` if there is none.
    std::optional<data>
    next() override;

    // Resets all pipelines and, when wrapping is enabled, the bookkeeping.
    void
    reset() override;

    // Saves the position of all pipelines (plus the bookkeeping when wrapping
    // is enabled) to `t`.
    void
    record_position(tape &t) const override;

    // Restores what `record_position()` saved.
    void
    reload_position(tape &t) override;

private:
    // Reads from one pipeline, wrapping around to its start when it is
    // exhausted and flagging it as done for the current epoch.
    std::optional<data>
    next_in_pipeline(std::size_t pipeline_idx);

    // True once every pipeline has been flagged as done.
    bool
    eod();

    // The pipelines examples are drawn from.
    std::vector<data_pipeline> pipelines_;

    // Picks the pipeline to read on each call to `next()`.
    index_generator_fn next_index_gen_;

    // One look-ahead slot per pipeline; only used when wrapping is enabled.
    std::vector<std::optional<data>> buffer_{};

    // Per-pipeline "completed an epoch" flags; only used when wrapping is
    // enabled.
    std::vector<bool> is_epoch_done_{};

    // Cached result of `eod()`.
    bool is_eod_ = false;

    // Iteration mode chosen at construction.
    bool stop_at_shortest_;
};

} // namespace fairseq2n::detail
15 changes: 9 additions & 6 deletions fairseq2n/src/fairseq2n/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ data_pipeline::zip(
}

data_pipeline_builder
data_pipeline::round_robin(std::vector<data_pipeline> pipelines)
data_pipeline::round_robin(
std::vector<data_pipeline> pipelines,
bool stop_at_shortest)
{
bool is_broken = std::any_of(
pipelines.begin(), pipelines.end(), [](const data_pipeline &pipeline)
Expand All @@ -217,9 +219,9 @@ data_pipeline::round_robin(std::vector<data_pipeline> pipelines)

auto tmp = std::make_shared<std::vector<data_pipeline>>(std::move(pipelines));

auto factory = [tmp]() mutable
auto factory = [tmp, stop_at_shortest]() mutable
{
return std::make_unique<round_robin_data_source>(std::move(*tmp));
return std::make_unique<round_robin_data_source>(std::move(*tmp), stop_at_shortest);
};

return data_pipeline_builder{std::move(factory)};
Expand All @@ -228,7 +230,8 @@ data_pipeline::round_robin(std::vector<data_pipeline> pipelines)
data_pipeline_builder
data_pipeline::sample(
std::vector<data_pipeline> pipelines,
std::optional<std::vector<float32>> weights)
std::optional<std::vector<float32>> weights,
bool stop_at_shortest)
{
if (pipelines.empty())
throw_<std::invalid_argument>(
Expand All @@ -251,8 +254,8 @@ data_pipeline::sample(

auto tmp = std::make_shared<std::vector<data_pipeline>>(std::move(pipelines));

auto factory = [tmp, weights=std::move(weights.value())]() mutable {
return std::make_unique<sample_data_source>(std::move(*tmp), std::move(weights));
auto factory = [tmp, weights=std::move(weights.value()), stop_at_shortest]() mutable {
return std::make_unique<sample_data_source>(std::move(*tmp), std::move(weights), stop_at_shortest);
};

return data_pipeline_builder{std::move(factory)};
Expand Down
7 changes: 5 additions & 2 deletions fairseq2n/src/fairseq2n/data/data_pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,15 @@ class FAIRSEQ2_API data_pipeline {
bool disable_parallelism = false);

static data_pipeline_builder
round_robin(std::vector<data_pipeline> pipelines);
round_robin(
std::vector<data_pipeline> pipelines,
bool stop_at_shortest = false);

static data_pipeline_builder
sample(
std::vector<data_pipeline> pipelines,
std::optional<std::vector<float>> weights = {});
std::optional<std::vector<float>> weights = {},
bool stop_at_shortest = false);

static data_pipeline_builder
constant(data example, std::optional<std::string> key = {});
Expand Down
Loading

0 comments on commit 52b05c7

Please sign in to comment.