Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[try_put_and_wait] Part 1: Add implementation of try_put_and_wait feature for function_node #1398

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions include/oneapi/tbb/detail/_flow_graph_body_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,19 +279,23 @@ class apply_body_task_bypass
NodeType &my_node;
Input my_input;

graph_task* call_apply_body_bypass_impl(std::true_type) {
using check_metainfo = std::is_same<BaseTaskType, graph_task>;
using not_contains_metainfo = std::true_type;
using contains_metainfo = std::false_type;

graph_task* call_apply_body_bypass_impl(not_contains_metainfo) {
kboyarinov marked this conversation as resolved.
Show resolved Hide resolved
return my_node.apply_body_bypass(my_input
__TBB_FLOW_GRAPH_METAINFO_ARG(message_metainfo{}));
}

#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
graph_task* call_apply_body_bypass_impl(std::false_type) {
graph_task* call_apply_body_bypass_impl(contains_metainfo) {
return my_node.apply_body_bypass(my_input, message_metainfo{this->get_msg_wait_context_vertices()});
}
#endif

graph_task* call_apply_body_bypass() {
return call_apply_body_bypass_impl(std::is_same<BaseTaskType, graph_task>{});
return call_apply_body_bypass_impl(check_metainfo{});
}

public:
Expand Down
5 changes: 2 additions & 3 deletions include/oneapi/tbb/detail/_flow_graph_cache_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ class successor_cache : no_copy {

virtual graph_task* try_put_task( const T& t ) = 0;
#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
virtual graph_task* try_put_task( const T& t, const message_metainfo& metainfo) = 0;
virtual graph_task* try_put_task( const T& t, const message_metainfo& metainfo ) = 0;
#endif
}; // successor_cache<T>

Expand Down Expand Up @@ -370,7 +370,6 @@ class broadcast_cache : public successor_cache<T, M> {
// Do not work with the passed pointer here as it may not be fully initialized yet
}

// as above, but call try_put_task instead, and return the last task we received (if any)
graph_task* try_put_task( const T &t ) override {
return try_put_task_impl(t __TBB_FLOW_GRAPH_METAINFO_ARG(message_metainfo{}));
}
Expand Down Expand Up @@ -447,7 +446,7 @@ class round_robin_cache : public successor_cache<T, M> {

#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
// TODO: add support for round robin cache
graph_task* try_put_task(const T& t, const message_metainfo&) override {
graph_task* try_put_task( const T& t, const message_metainfo& ) override {
return try_put_task(t);
}
#endif
Expand Down
11 changes: 9 additions & 2 deletions include/oneapi/tbb/detail/_flow_graph_item_buffer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,15 @@ class item_buffer {
#endif

// following methods are for reservation of the front of a buffer.
void reserve_item(size_type i) { __TBB_ASSERT(my_item_valid(i) && !my_item_reserved(i), "item cannot be reserved"); element(i).state = reserved_item; }
void release_item(size_type i) { __TBB_ASSERT(my_item_reserved(i), "item is not reserved"); element(i).state = has_item; }
void reserve_item(size_type i) {
__TBB_ASSERT(my_item_valid(i) && !my_item_reserved(i), "item cannot be reserved");
element(i).state = reserved_item;
}

void release_item(size_type i) {
__TBB_ASSERT(my_item_reserved(i), "item is not reserved");
element(i).state = has_item;
}

void destroy_front() { destroy_item(my_head); ++my_head; }
void destroy_back() { destroy_item(my_tail-1); --my_tail; }
Expand Down
2 changes: 1 addition & 1 deletion include/oneapi/tbb/detail/_flow_graph_node_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class function_input_base : public receiver<Input>, no_assign {

friend class apply_body_task_bypass< class_type, input_type >;
#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
friend class apply_body_task_bypass< class_type, input_type, graph_task_with_message_waiters>;
friend class apply_body_task_bypass< class_type, input_type, graph_task_with_message_waiters >;
#endif
friend class forward_task_bypass< class_type >;

Expand Down
6 changes: 3 additions & 3 deletions include/oneapi/tbb/flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,12 @@ class receiver {
//! Put an item to the receiver and wait for completion
bool try_put_and_wait( const T& t ) {
// Since try_put_and_wait is a blocking call, it is safe to create wait_context on stack
d1::wait_context_vertex msg_wait_context{};
d1::wait_context_vertex msg_wait_vertex{};

bool res = internal_try_put(t, message_metainfo{message_metainfo::waiters_type{&msg_wait_context}});
bool res = internal_try_put(t, message_metainfo{message_metainfo::waiters_type{&msg_wait_vertex}});
if (res) {
__TBB_ASSERT(graph_reference().my_context != nullptr, "No wait_context associated with the Flow Graph");
wait(msg_wait_context.get_context(), *graph_reference().my_context);
wait(msg_wait_vertex.get_context(), *graph_reference().my_context);
}
return res;
}
Expand Down
4 changes: 2 additions & 2 deletions test/common/graph_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ struct harness_counting_receiver : public tbb::flow::receiver<T> {
}

#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
tbb::detail::d2::graph_task *try_put_task( const T &t, const tbb::detail::d2::message_metainfo&) override {
tbb::detail::d2::graph_task *try_put_task( const T &t, const tbb::detail::d2::message_metainfo& ) override {
return try_put_task(t);
}
#endif
Expand Down Expand Up @@ -339,7 +339,7 @@ struct harness_mapped_receiver : public tbb::flow::receiver<T> {
}

#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
tbb::detail::d2::graph_task *try_put_task( const T &t, const tbb::detail::d2::message_metainfo&) override {
tbb::detail::d2::graph_task *try_put_task( const T &t, const tbb::detail::d2::message_metainfo& ) override {
return try_put_task(t);
}
#endif
Expand Down
29 changes: 20 additions & 9 deletions test/tbb/test_function_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,13 @@ void test_follows_and_precedes_api() {
#endif

#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
// Basic idea of the following tests is to check that try_put_and_wait(message) call for function_node
// with one of the policies (lightweight, queueing and rejecting) with different concurrency limits
processes all of the previous jobs required to process the message, the message itself, but does
// not process the elements submitted later or not required to process the message
These tests submit start_work_items using the regular try_put and then submit wait_message
kboyarinov marked this conversation as resolved.
Show resolved Hide resolved
// with try_put_and_wait. During the completion of the graph, new_work_items would be submitted
// once the wait_message arrives.
void test_try_put_and_wait_lightweight(std::size_t concurrency_limit) {
tbb::task_arena arena(1);

Expand All @@ -484,9 +491,7 @@ void test_try_put_and_wait_lightweight(std::size_t concurrency_limit) {

for (int i = 0; i < wait_message; ++i) {
start_work_items.emplace_back(i);
if (i != 0) {
new_work_items.emplace_back(i + 10);
}
new_work_items.emplace_back(i + 1 + wait_message);
}

using function_node_type = tbb::flow::function_node<int, int, tbb::flow::lightweight>;
Expand Down Expand Up @@ -529,10 +534,13 @@ void test_try_put_and_wait_lightweight(std::size_t concurrency_limit) {
if (concurrency_limit == tbb::flow::serial) {
// If the lightweight function_node is serial, it should process the wait_message but add items from new_work_items
// into the queue since the concurrency limit is occupied.
CHECK_MESSAGE(processed_items.size() == start_work_items.size() + 1, "Unexpected number of elements processed");
CHECK_MESSAGE(processed_items[check_index++] == wait_message, "Unexpected items processing");
} else {
// If the node is unlimited, it should process new_work_items immediately while processing the wait_message
// Hence they should be processed before exiting the try_put_and_wait
CHECK_MESSAGE(processed_items.size() == start_work_items.size() + new_work_items.size() + 1,
"Unexpected number of elements processed");
for (auto item : new_work_items) {
CHECK_MESSAGE(processed_items[check_index++] == item, "Unexpected items processing");
}
Expand All @@ -549,6 +557,7 @@ void test_try_put_and_wait_lightweight(std::size_t concurrency_limit) {
CHECK_MESSAGE(processed_items[check_index++] == item, "Unexpected items processing");
}
}
CHECK(check_index == processed_items.size());
});
}

Expand All @@ -565,9 +574,7 @@ void test_try_put_and_wait_queueing(std::size_t concurrency_limit) {

for (int i = 0; i < wait_message; ++i) {
start_work_items.emplace_back(i);
if (i != 0) {
new_work_items.emplace_back(i + 10);
}
new_work_items.emplace_back(i + 1 + wait_message);
}

using function_node_type = tbb::flow::function_node<int, int, tbb::flow::queueing>;
Expand Down Expand Up @@ -605,9 +612,12 @@ void test_try_put_and_wait_queueing(std::size_t concurrency_limit) {
// Serial queueing function_node should add all start_work_items except the first one into the queue
// and then process them in FIFO order.
// wait_message would also be added to the queue, but would be processed later
CHECK_MESSAGE(processed_items.size() == start_work_items.size() + 1, "Unexpected number of elements processed");
for (auto item : start_work_items) {
CHECK_MESSAGE(processed_items[check_index++] == item, "Unexpected items processing");
}
} else {
CHECK_MESSAGE(processed_items.size() == 1, "Unexpected number of elements processed");
}

// For the unlimited function_node, all of the tasks for start_work_items and wait_message would be spawned
Expand All @@ -634,6 +644,7 @@ void test_try_put_and_wait_queueing(std::size_t concurrency_limit) {
CHECK_MESSAGE(processed_items[check_index++] == start_work_items[i - 1], "Unexpected items processing");
}
}
CHECK(check_index == processed_items.size());
});
}

Expand Down Expand Up @@ -678,7 +689,7 @@ void test_try_put_and_wait_rejecting(size_t concurrency_limit) {
// If the first action is try_put_and_wait, it will occupy concurrency of the function_node
// All submits of new_work_items inside of the body should be rejected
bool result = function.try_put_and_wait(wait_message);
CHECK_MESSAGE(result, "task should not rejected since the node concurrency is not acquired");
CHECK_MESSAGE(result, "task should not rejected since the node concurrency is not saturated");

CHECK_MESSAGE(processed_items.size() == 1, nullptr);
CHECK_MESSAGE(processed_items[0] == wait_message, "Unexpected items processing");
Expand All @@ -690,10 +701,10 @@ void test_try_put_and_wait_rejecting(size_t concurrency_limit) {
processed_items.clear();

// If the first action is try_put, try_put_and_wait is expected to return false since the concurrency of the
// node would be acquired
// node would be saturated
function.try_put(0);
result = function.try_put_and_wait(wait_message);
CHECK_MESSAGE(!result, "task should be rejected since the node concurrency is acquired");
CHECK_MESSAGE(!result, "task should be rejected since the node concurrency is saturated");
CHECK(processed_items.empty());

g.wait_for_all();
Expand Down
Loading