-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce stop_at_shortest in sample and round_robin (#76)
* Add tensor.h to utils
* Add sample_data_source class
* Add sample method to data_pipeline
* Add sample binding to data_pipeline
* Add test_sample
* Fix build errors
* fix python lint errors
* Add empty lines at end of c++ files
* remove \n from runtime error regex match
* Use float32 instead of float
* Resolve nit comments
* Remove set_seed from sampler
* Fix document error
* Handle generator state out of sample operator
* Add circular data source
* Use circular data_source in round_robin
* fix circular data source
* Add stop_on_shortest flag
* Add stop_at_shortest flag to round_robin
* Add stop_at_shortest round_robin test
* Fix circular data source stop_at_shortest flag
* Use circular data_source in sample op
* Add up_sampling test
* Rename circular_data_source to multi_data_source
* Improve sample and round_robin doc
* lint code
* skip test that uses new API
* nit changes
* rename multi_data_source to composite_data_source
* use std::exchange instead of =
* resolve last nit comment
- Loading branch information
1 parent
e081284
commit 52b05c7
Showing
13 changed files
with
302 additions
and
146 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
// Copyright (c) Meta Platforms, Inc. and affiliates. | ||
// All rights reserved. | ||
// | ||
// This source code is licensed under the BSD-style license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#include "fairseq2n/data/composite_data_source.h" | ||
|
||
namespace fairseq2n::detail { | ||
|
||
composite_data_source::composite_data_source( | ||
std::vector<data_pipeline> &&pipelines, | ||
index_generator_fn &&index_gen_fn, | ||
bool stop_at_shortest) | ||
: pipelines_(std::move(pipelines)), | ||
next_index_gen_{std::move(index_gen_fn)}, | ||
stop_at_shortest_{stop_at_shortest} | ||
{ | ||
if (!stop_at_shortest) { | ||
is_epoch_done_ = std::vector<bool>(pipelines_.size(), false); | ||
buffer_ = std::vector<std::optional<data>>(pipelines_.size(), std::nullopt); | ||
} | ||
} | ||
|
||
std::optional<data> | ||
composite_data_source::next() | ||
{ | ||
if (stop_at_shortest_) // with this flag on, the operator is a simple iterator | ||
return pipelines_[next_index_gen_()].next(); | ||
|
||
// One or more data pipelines might be empty, so we have to keep looping | ||
std::optional<data> output{}; | ||
while (!output && !eod()) { | ||
auto pipeline_idx = next_index_gen_(); | ||
auto &maybe_example = buffer_[pipeline_idx]; | ||
|
||
if (!maybe_example) // init buffer at first call | ||
maybe_example = next_in_pipeline(pipeline_idx); | ||
|
||
output = std::exchange(maybe_example, next_in_pipeline(pipeline_idx)); | ||
} | ||
|
||
return output; | ||
} | ||
|
||
void | ||
composite_data_source::reset() | ||
{ | ||
for (data_pipeline &pipeline : pipelines_) | ||
pipeline.reset(); | ||
|
||
if (!stop_at_shortest_) { | ||
buffer_.assign(pipelines_.size(), std::nullopt); | ||
is_epoch_done_.assign(pipelines_.size(), false); | ||
is_eod_ = false; | ||
} | ||
} | ||
|
||
void | ||
composite_data_source::record_position(tape &t) const | ||
{ | ||
for (const data_pipeline &pipeline : pipelines_) | ||
pipeline.record_position(t); | ||
|
||
if (!stop_at_shortest_) { | ||
t.record(buffer_); | ||
t.record(is_epoch_done_); | ||
} | ||
} | ||
|
||
void | ||
composite_data_source::reload_position(tape &t) | ||
{ | ||
for (data_pipeline &pipeline : pipelines_) | ||
pipeline.reload_position(t); | ||
|
||
if (!stop_at_shortest_) { | ||
buffer_ = t.read<std::vector<std::optional<data>>>(); | ||
is_epoch_done_ = t.read<std::vector<bool>>(); | ||
is_eod_ = false; | ||
} | ||
} | ||
|
||
std::optional<data> | ||
composite_data_source::next_in_pipeline(std::size_t pipeline_idx) | ||
{ | ||
data_pipeline &pipeline = pipelines_[pipeline_idx]; | ||
|
||
std::optional<data> maybe_example = pipeline.next(); | ||
if (!maybe_example) { | ||
is_epoch_done_[pipeline_idx] = true; | ||
|
||
pipeline.reset(); | ||
|
||
// Circle back to the first example. | ||
maybe_example = pipeline.next(); | ||
} | ||
|
||
return maybe_example; | ||
} | ||
|
||
bool | ||
composite_data_source::eod() | ||
{ | ||
is_eod_ = is_eod_ || std::all_of( | ||
is_epoch_done_.begin(), is_epoch_done_.end(), [](bool b) | ||
{ | ||
return b; | ||
}); | ||
|
||
return is_eod_; | ||
} | ||
|
||
} // namespace fairseq2n::detail |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Copyright (c) Meta Platforms, Inc. and affiliates. | ||
// All rights reserved. | ||
// | ||
// This source code is licensed under the BSD-style license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#pragma once | ||
|
||
#include <vector> | ||
|
||
#include "fairseq2n/data/data_pipeline.h" | ||
|
||
|
||
// Callable that yields the index of the pipeline composite_data_source should
// read from next. Note that the alias is declared at global scope, outside of
// the fairseq2n::detail namespace opened below.
using index_generator_fn = std::function<std::size_t()>; | ||
|
||
namespace fairseq2n::detail { | ||
|
||
class composite_data_source final : public data_source { | ||
public: | ||
explicit | ||
composite_data_source(std::vector<data_pipeline> &&pipelines, index_generator_fn &&index_gen_fn, bool stop_at_shortest); | ||
|
||
std::optional<data> | ||
next() override; | ||
|
||
void | ||
reset() override; | ||
|
||
void | ||
record_position(tape &t) const override; | ||
|
||
void | ||
reload_position(tape &t) override; | ||
|
||
private: | ||
std::optional<data> | ||
next_in_pipeline(std::size_t pipeline_idx); | ||
|
||
bool | ||
eod(); | ||
|
||
private: | ||
std::vector<data_pipeline> pipelines_; | ||
index_generator_fn next_index_gen_; | ||
std::vector<std::optional<data>> buffer_{}; | ||
std::vector<bool> is_epoch_done_; | ||
bool is_eod_ = false; | ||
bool stop_at_shortest_; | ||
}; | ||
|
||
} // namespace fairseq2n::detail |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.