Add implementation of try_put_and_wait feature for continue nodes
dnmokhov committed Jul 20, 2024
1 parent 4029091 · commit 324ac2d
Showing 5 changed files with 164 additions and 11 deletions.
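For orientation, here is a minimal usage sketch of the feature this commit enables, modeled on test_try_put_and_wait_default() added below. It assumes the preview guard __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT (the macro used throughout this diff) is enabled at build time; the graph topology and node bodies are illustrative and not part of the commit.

#include "oneapi/tbb/flow_graph.h"

int main() {
    tbb::flow::graph g;

    // A continue_node whose body produces a message for its successor.
    tbb::flow::continue_node<tbb::flow::continue_msg> producer(g,
        [](tbb::flow::continue_msg) noexcept {
            return tbb::flow::continue_msg{};
        });

    // A lightweight successor, mirroring the "writer" node in the new test.
    tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight> writer(g,
        [](tbb::flow::continue_msg) noexcept {
            return tbb::flow::continue_msg{};
        });

    tbb::flow::make_edge(producer, writer);

#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
    // Blocks until this particular message and the work it triggers downstream
    // are complete, without draining unrelated work queued in the graph.
    producer.try_put_and_wait(tbb::flow::continue_msg{});
#endif

    // Completes whatever work remains in the graph.
    g.wait_for_all();
    return 0;
}

The difference from a plain try_put() followed by wait_for_all() is visible in the default-policy test below: try_put_and_wait() returns once its own message has been processed (processed_items == 1), while the ten extra messages spawned by the body are only drained by the subsequent wait_for_all() (processed_items == 11).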
.github/workflows/ci.yml: 2 changes (1 addition, 1 deletion)
@@ -19,7 +19,7 @@ on:
branches: [master]

pull_request:
- branches: [master, dev/kboyarinov/try_put_and_wait_production]
+ branches: [master, dev/kboyarinov/try_put_and_wait_production, dev/kboyarinov/try_put_and_wait_function_node]
types:
- opened
- synchronize
include/oneapi/tbb/detail/_flow_graph_body_impl.h: 6 changes (6 additions, 0 deletions)
@@ -396,6 +396,12 @@ class threshold_regulator<T, continue_msg, void> : public continue_receiver, no_
return my_node->decrement_counter( 1 );
}

+ #if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
+ graph_task* execute(const message_metainfo&) override {
+ return execute();
+ }
+ #endif

protected:

graph& graph_reference() const override {
include/oneapi/tbb/detail/_flow_graph_node_impl.h: 21 changes (15 additions, 6 deletions)
@@ -722,19 +722,27 @@ class continue_input : public continue_receiver {

virtual broadcast_cache<output_type > &successors() = 0;

- friend class apply_body_task_bypass< class_type, continue_msg >;
+ friend class apply_body_task_bypass< class_type, continue_msg __TBB_FLOW_GRAPH_METAINFO_ARG(graph_task_with_message_waiters) >;

//! Applies the body to the provided input
- graph_task* apply_body_bypass( input_type __TBB_FLOW_GRAPH_METAINFO_ARG(const message_metainfo&) ) {
+ graph_task* apply_body_bypass( input_type __TBB_FLOW_GRAPH_METAINFO_ARG(const message_metainfo& metainfo) ) {
// There is an extra copy needed to capture the
// body execution without the try_put
fgt_begin_body( my_body );
output_type v = (*my_body)( continue_msg() );
fgt_end_body( my_body );
- return successors().try_put_task( v );
+ return successors().try_put_task( v __TBB_FLOW_GRAPH_METAINFO_ARG(metainfo) );
}

graph_task* execute() override {
+ #if !__TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
+ message_metainfo metainfo{};
+ #else
+ return execute(message_metainfo{});
+ }
+
+ graph_task* execute(const message_metainfo& metainfo) override {
+ #endif
if(!is_graph_active(my_graph_ref)) {
return nullptr;
}
@@ -746,12 +754,13 @@
#if _MSC_VER && !__INTEL_COMPILER
#pragma warning (pop)
#endif
- return apply_body_bypass( continue_msg() __TBB_FLOW_GRAPH_METAINFO_ARG(message_metainfo{}) );
+ return apply_body_bypass( continue_msg() __TBB_FLOW_GRAPH_METAINFO_ARG(metainfo) );
}
else {
d1::small_object_allocator allocator{};
- typedef apply_body_task_bypass<class_type, continue_msg> task_type;
- graph_task* t = allocator.new_object<task_type>( graph_reference(), allocator, *this, continue_msg(), my_priority );
+ typedef apply_body_task_bypass<class_type, continue_msg __TBB_FLOW_GRAPH_METAINFO_ARG(graph_task_with_message_waiters)> task_type;
+ graph_task* t = allocator.new_object<task_type>( graph_reference(), allocator, *this, continue_msg(),
+     my_priority __TBB_FLOW_GRAPH_METAINFO_ARG(metainfo));
return t;
}
}
include/oneapi/tbb/flow_graph.h: 21 changes (17 additions, 4 deletions)
@@ -385,23 +385,33 @@ class continue_receiver : public receiver< continue_msg > {
template< typename R, typename B > friend class run_and_put_task;
template<typename X, typename Y> friend class broadcast_cache;
template<typename X, typename Y> friend class round_robin_cache;

+ private:
// execute body is supposed to be too small to create a task for.
- graph_task* try_put_task( const input_type & ) override {
+ graph_task* try_put_task_impl( const input_type& __TBB_FLOW_GRAPH_METAINFO_ARG(const message_metainfo& metainfo) ) {
{
spin_mutex::scoped_lock l(my_mutex);
if ( ++my_current_count < my_predecessor_count )
return SUCCESSFULLY_ENQUEUED;
else
my_current_count = 0;
}
+ #if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
+ graph_task* res = execute(metainfo);
+ #else
graph_task* res = execute();
+ #endif
return res? res : SUCCESSFULLY_ENQUEUED;
}

+ protected:
+ graph_task* try_put_task( const input_type& input ) override {
+ return try_put_task_impl(input __TBB_FLOW_GRAPH_METAINFO_ARG(message_metainfo{}));
+ }

#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
- // TODO: add metainfo support for continue_receiver
- graph_task* try_put_task( const input_type& input, const message_metainfo& ) override {
- return try_put_task(input);
+ graph_task* try_put_task( const input_type& input, const message_metainfo& metainfo ) override {
+ return try_put_task_impl(input, metainfo);
}
#endif

@@ -425,6 +435,9 @@ class continue_receiver : public receiver< continue_msg > {
/** This should be very fast or else spawn a task. This is
called while the sender is blocked in the try_put(). */
virtual graph_task* execute() = 0;
+ #if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
+ virtual graph_task* execute(const message_metainfo& metainfo) = 0;
+ #endif
template<typename TT, typename M> friend class successor_cache;
bool is_continue_receiver() override { return true; }

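The try_put_task_impl refactoring above preserves continue_receiver's existing trigger rule: the body is executed only once every registered predecessor has signalled, and under the preview macro the metainfo of the message that completes the count is what gets forwarded to execute(). A small sketch of that counting behavior as seen from user code, with illustrative node names (the metainfo plumbing itself is internal and not visible here):

#include "oneapi/tbb/flow_graph.h"
#include <iostream>

int main() {
    tbb::flow::graph g;

    tbb::flow::broadcast_node<tbb::flow::continue_msg> a(g);
    tbb::flow::broadcast_node<tbb::flow::continue_msg> b(g);

    // 'sync' has two predecessors, so its body runs only after both a and b
    // have delivered a continue_msg (++my_current_count reaching my_predecessor_count),
    // at which point the counter is reset for the next wave.
    tbb::flow::continue_node<tbb::flow::continue_msg> sync(g,
        [](tbb::flow::continue_msg) noexcept {
            std::cout << "both predecessors signalled\n";
            return tbb::flow::continue_msg{};
        });

    tbb::flow::make_edge(a, sync);
    tbb::flow::make_edge(b, sync);

    a.try_put(tbb::flow::continue_msg{});   // first signal: body does not run yet
    b.try_put(tbb::flow::continue_msg{});   // second signal: threshold reached, body runs

    g.wait_for_all();
    return 0;
}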
test/tbb/test_continue_node.cpp: 125 changes (125 additions, 0 deletions)
@@ -354,6 +354,124 @@ void test_successor_cache_specialization() {
"Wrong number of messages is passed via continue_node");
}

#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
void test_try_put_and_wait_default() {
tbb::task_arena arena(1);

arena.execute([&]{
tbb::flow::graph g;

int processed_items = 0;

tbb::flow::continue_node<tbb::flow::continue_msg>* start_node = nullptr;

tbb::flow::continue_node<tbb::flow::continue_msg> cont(g,
[&](tbb::flow::continue_msg) noexcept {
static bool put_ten_msgs = true;
if (put_ten_msgs) {
for (std::size_t i = 0; i < 10; ++i) {
start_node->try_put(tbb::flow::continue_msg{});
}
put_ten_msgs = false;
}
});

start_node = &cont;

tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight> writer(g,
[&](tbb::flow::continue_msg) noexcept {
++processed_items;
});

tbb::flow::make_edge(cont, writer);

cont.try_put_and_wait(tbb::flow::continue_msg{});

// Only 1 item should be processed, with the additional 10 items having been spawned
CHECK_MESSAGE(processed_items == 1, "Unexpected items processing");

g.wait_for_all();

// The additional 10 items should be processed
CHECK_MESSAGE(processed_items == 11, "Unexpected items processing");
});
}

void test_try_put_and_wait_lightweight() {
tbb::task_arena arena(1);

arena.execute([&]{
tbb::flow::graph g;

std::vector<int> start_work_items;
std::vector<int> processed_items;
std::vector<int> new_work_items;

int wait_message = 10;

for (int i = 0; i < wait_message; ++i) {
start_work_items.emplace_back(i);
if (i != 0) {
new_work_items.emplace_back(i + 10);
}
}

tbb::flow::continue_node<int, tbb::flow::lightweight>* start_node = nullptr;

tbb::flow::continue_node<int, tbb::flow::lightweight> cont(g,
[&](tbb::flow::continue_msg) noexcept {
static int counter = 0;
int i = counter++;
if (i == wait_message) {
for (int item : new_work_items) {
(void)item;
start_node->try_put(tbb::flow::continue_msg{});
}
}
return i;
});

start_node = &cont;

tbb::flow::function_node<int, int, tbb::flow::lightweight> writer(g, tbb::flow::unlimited,
[&](int input) noexcept {
processed_items.emplace_back(input);
return 0;
});

tbb::flow::make_edge(cont, writer);

for (int i = 0; i < wait_message; ++i) {
cont.try_put(tbb::flow::continue_msg{});
}

cont.try_put_and_wait(tbb::flow::continue_msg{});

std::size_t check_index = 0;

// For lightweight continue_node, start_work_items are expected to be processed first
// while putting items into the first node
for (auto item : start_work_items) {
CHECK_MESSAGE(processed_items[check_index++] == item, "Unexpected items processing");
}

// If the node is unlimited, it should process new_work_items immediately while processing the wait_message
for (auto item : new_work_items) {
CHECK_MESSAGE(processed_items[check_index++] == item, "Unexpected items processing");
}
// wait_message would be processed only after new_work_items
CHECK_MESSAGE(processed_items[check_index++] == wait_message, "Unexpected items processing");

g.wait_for_all();
});
}

void test_try_put_and_wait() {
test_try_put_and_wait_default();
test_try_put_and_wait_lightweight();
}
#endif // __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT

//! Test concurrent continue_node for correctness
//! \brief \ref error_guessing
TEST_CASE("Concurrency testing") {
@@ -418,3 +536,10 @@ TEST_CASE("constraints for continue_node body") {
static_assert(!can_call_continue_node_ctor<output_type, WrongReturnOperatorRoundBrackets<output_type>>);
}
#endif // __TBB_CPP20_CONCEPTS_PRESENT

#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
//! \brief \ref error_guessing
TEST_CASE("test continue_node try_put_and_wait") {
test_try_put_and_wait();
}
#endif
