Skip to content

Commit

Permalink
Add stop_at_shortest flag
Browse files Browse the repository at this point in the history
  • Loading branch information
najielhachem committed Oct 2, 2023
1 parent 73e348f commit a24baa0
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 6 deletions.
7 changes: 4 additions & 3 deletions fairseq2n/src/fairseq2n/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ data_pipeline::round_robin(std::vector<data_pipeline> pipelines)
data_pipeline_builder
data_pipeline::sample(
std::vector<data_pipeline> pipelines,
std::optional<std::vector<float32>> weights)
std::optional<std::vector<float32>> weights,
bool stop_at_shortest)
{
if (pipelines.empty())
throw_<std::invalid_argument>(
Expand All @@ -251,8 +252,8 @@ data_pipeline::sample(

auto tmp = std::make_shared<std::vector<data_pipeline>>(std::move(pipelines));

auto factory = [tmp, weights=std::move(weights.value())]() mutable {
return std::make_unique<sample_data_source>(std::move(*tmp), std::move(weights));
auto factory = [tmp, weights=std::move(weights.value()), stop_at_shortest]() mutable {
return std::make_unique<sample_data_source>(std::move(*tmp), std::move(weights), stop_at_shortest);
};

return data_pipeline_builder{std::move(factory)};
Expand Down
3 changes: 2 additions & 1 deletion fairseq2n/src/fairseq2n/data/data_pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ class FAIRSEQ2_API data_pipeline {
static data_pipeline_builder
sample(
std::vector<data_pipeline> pipelines,
std::optional<std::vector<float>> weights = {});
std::optional<std::vector<float>> weights = {},
bool stop_at_shortest = true);

static data_pipeline_builder
constant(data example, std::optional<std::string> key = {});
Expand Down
2 changes: 1 addition & 1 deletion fairseq2n/src/fairseq2n/data/sample_data_source.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

namespace fairseq2n::detail {

sample_data_source::sample_data_source(std::vector<data_pipeline> &&pipelines, std::vector<float32> &&weights)
sample_data_source::sample_data_source(std::vector<data_pipeline> &&pipelines, std::vector<float32> &&weights, bool stop_at_shortest)

Check failure on line 19 in fairseq2n/src/fairseq2n/data/sample_data_source.cc

View workflow job for this annotation

GitHub Actions / Lint C++ / Lint

unused parameter 'stop_at_shortest' [clang-diagnostic-unused-parameter,-warnings-as-errors]

Check failure on line 19 in fairseq2n/src/fairseq2n/data/sample_data_source.cc

View workflow job for this annotation

GitHub Actions / Lint C++ / Lint

parameter 'stop_at_shortest' is unused [misc-unused-parameters,-warnings-as-errors]
: pipelines_(std::move(pipelines))
{
weights_ = make_tensor_from_vector(weights, { static_cast<std::int64_t>(pipelines_.size()) });
Expand Down
2 changes: 1 addition & 1 deletion fairseq2n/src/fairseq2n/data/sample_data_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace fairseq2n::detail {
class sample_data_source final : public data_source {
public:
explicit
sample_data_source(std::vector<data_pipeline> &&pipelines, std::vector<float32> &&weights);
sample_data_source(std::vector<data_pipeline> &&pipelines, std::vector<float32> &&weights, bool stop_at_shortest);

std::optional<data>
next() override;
Expand Down
3 changes: 3 additions & 0 deletions src/fairseq2/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,16 @@ def round_robin(pipelines: Sequence["DataPipeline"]) -> "DataPipelineBuilder":
def sample(
pipelines: Sequence["DataPipeline"],
weights: Optional[Sequence[float]] = None,
stop_at_shortest: bool = True,
) -> "DataPipelineBuilder":
"""Extract examples from ``pipelines`` by sampling based on ``weights``.

:param pipelines:
The data pipelines to sample from.
:param weights:
Desired distribution of pipelines. If None, use uniform distribution.
:param stop_at_shortest:
If ``True``, stop sampling as soon as the first pipeline reaches its end.
"""

@staticmethod
Expand Down

0 comments on commit a24baa0

Please sign in to comment.