diff --git a/include/oneapi/tbb/detail/_flow_graph_join_impl.h b/include/oneapi/tbb/detail/_flow_graph_join_impl.h index 90a77b6b87..bf36288655 100644 --- a/include/oneapi/tbb/detail/_flow_graph_join_impl.h +++ b/include/oneapi/tbb/detail/_flow_graph_join_impl.h @@ -656,13 +656,23 @@ const K& operator()(const table_item_type& v) { return v.my_key; } }; + template + struct key_matching_port_base { +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + using type = metainfo_hash_buffer; +#else + using type = hash_buffer; +#endif + }; + // the ports can have only one template parameter. We wrap the types needed in // a traits type template< class TraitsType > class key_matching_port : public receiver, - public hash_buffer< typename TraitsType::K, typename TraitsType::T, typename TraitsType::TtoK, - typename TraitsType::KHash > { + public key_matching_port_base< typename TraitsType::K, typename TraitsType::T, typename TraitsType::TtoK, + typename TraitsType::KHash >::type + { public: typedef TraitsType traits; typedef key_matching_port class_type; @@ -672,7 +682,7 @@ typedef typename receiver::predecessor_type predecessor_type; typedef typename TraitsType::TtoK type_to_key_func_type; typedef typename TraitsType::KHash hash_compare_type; - typedef hash_buffer< key_type, input_type, type_to_key_func_type, hash_compare_type > buffer_type; + typedef typename key_matching_port_base::type buffer_type; private: // ----------- Aggregator ------------ @@ -685,12 +695,21 @@ char type; input_type my_val; input_type *my_arg; +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + message_metainfo* metainfo = nullptr; +#endif // constructor for value parameter - key_matching_port_operation(const input_type& e, op_type t) : - type(char(t)), my_val(e), my_arg(nullptr) {} + key_matching_port_operation(const input_type& e, op_type t + __TBB_FLOW_GRAPH_METAINFO_ARG(const message_metainfo& info)) + : type(char(t)), my_val(e), my_arg(nullptr) + __TBB_FLOW_GRAPH_METAINFO_ARG(metainfo(const_cast(&info))) {} + // constructor for pointer parameter - key_matching_port_operation(const input_type* p, op_type t) : - type(char(t)), my_arg(const_cast(p)) {} + key_matching_port_operation(const input_type* p, op_type t + __TBB_FLOW_GRAPH_METAINFO_ARG(message_metainfo& info)) + : type(char(t)), my_arg(const_cast(p)) + __TBB_FLOW_GRAPH_METAINFO_ARG(metainfo(&info)) {} + // constructor with no parameter key_matching_port_operation(op_type t) : type(char(t)), my_arg(nullptr) {} }; @@ -706,18 +725,40 @@ op_list = op_list->next; switch(current->type) { case try__put: { - bool was_inserted = this->insert_with_key(current->my_val); + bool was_inserted = false; +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + if (current->metainfo) { + was_inserted = this->insert_with_key(current->my_val, *(current->metainfo)); + } else +#endif + { + was_inserted = this->insert_with_key(current->my_val); + } // return failure if a duplicate insertion occurs current->status.store( was_inserted ? SUCCEEDED : FAILED, std::memory_order_release); } break; - case get__item: + case get__item: { // use current_key from FE for item __TBB_ASSERT(current->my_arg, nullptr); - if(!this->find_with_key(my_join->current_key, *(current->my_arg))) { + bool find_result = false; +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + if (current->metainfo) { + find_result = this->find_with_key(my_join->current_key, *(current->my_arg), + *(current->metainfo)); + + } else +#endif + { + find_result = this->find_with_key(my_join->current_key, *(current->my_arg)); + } +#if TBB_USE_DEBUG + if (!find_result) { __TBB_ASSERT(false, "Failed to find item corresponding to current_key."); } +#endif current->status.store( SUCCEEDED, std::memory_order_release); + } break; case res_port: // use current_key from FE for item @@ -732,22 +773,26 @@ template< typename R, typename B > friend class run_and_put_task; template friend class broadcast_cache; template friend class round_robin_cache; - graph_task* try_put_task(const input_type& v) override { - key_matching_port_operation op_data(v, try__put); + private: + graph_task* try_put_task_impl(const input_type& v __TBB_FLOW_GRAPH_METAINFO_ARG(const message_metainfo& metainfo)) { + key_matching_port_operation op_data(v, try__put __TBB_FLOW_GRAPH_METAINFO_ARG(metainfo)); graph_task* rtask = nullptr; my_aggregator.execute(&op_data); if(op_data.status == SUCCEEDED) { - rtask = my_join->increment_key_count((*(this->get_key_func()))(v)); // may spawn + rtask = my_join->increment_key_count((*(this->get_key_func()))(v)); // may spawn // rtask has to reflect the return status of the try_put if(!rtask) rtask = SUCCESSFULLY_ENQUEUED; } return rtask; } + protected: + graph_task* try_put_task(const input_type& v) override { + return try_put_task_impl(v __TBB_FLOW_GRAPH_METAINFO_ARG(message_metainfo{})); + } #if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT - // TODO: add support for key_matching join_node - graph_task* try_put_task(const input_type& v, const message_metainfo&) override { - return try_put_task(v); + graph_task* try_put_task(const input_type& v, const message_metainfo& metainfo) override { + return try_put_task_impl(v, metainfo); } #endif @@ -786,6 +831,15 @@ return op_data.status == SUCCEEDED; } +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + bool get_item( input_type& v, message_metainfo& metainfo ) { + // aggregator uses current_key from FE for Key + key_matching_port_operation op_data(&v, get__item, metainfo); + my_aggregator.execute(&op_data); + return op_data.status == SUCCEEDED; + } +#endif + // reset_port is called when item is accepted by successor, but // is initiated by join_node. void reset_port() { @@ -1018,10 +1072,17 @@ unref_key_type my_val; output_type* my_output; graph_task* bypass_t; +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + message_metainfo* metainfo = nullptr; +#endif // constructor for value parameter key_matching_FE_operation(const unref_key_type& e , op_type t) : type(char(t)), my_val(e), my_output(nullptr), bypass_t(nullptr) {} key_matching_FE_operation(output_type *p, op_type t) : type(char(t)), my_output(p), bypass_t(nullptr) {} +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + key_matching_FE_operation(output_type *p, op_type t, message_metainfo& info) + : type(char(t)), my_output(p), bypass_t(nullptr), metainfo(&info) {} +#endif // constructor with no parameter key_matching_FE_operation(op_type t) : type(char(t)), my_output(nullptr), bypass_t(nullptr) {} }; @@ -1039,8 +1100,11 @@ bool do_fwd = this->buffer_empty() && is_graph_active(this->graph_ref); this->current_key = t; this->delete_with_key(this->current_key); // remove the key - if(join_helper::get_items(my_inputs, l_out)) { // <== call back - this->push_back(l_out); +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + message_metainfo metainfo; +#endif + if(join_helper::get_items(my_inputs, l_out __TBB_FLOW_GRAPH_METAINFO_ARG(metainfo))) { // <== call back + this->push_back(l_out __TBB_FLOW_GRAPH_METAINFO_ARG(metainfo)); if(do_fwd) { // we enqueue if receiving an item from predecessor, not if successor asks for item d1::small_object_allocator allocator{}; typedef forward_task_bypass task_type; @@ -1094,6 +1158,9 @@ } else { *(current->my_output) = this->front(); +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + *(current->metainfo) = this->front_metainfo(); +#endif current->status.store( SUCCEEDED, std::memory_order_release); } break; @@ -1168,8 +1235,10 @@ } #if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT - bool try_to_make_tuple(output_type &out, message_metainfo&) { - return try_to_make_tuple(out); + bool try_to_make_tuple(output_type &out, message_metainfo& metainfo) { + key_matching_FE_operation op_data(&out, try_make, metainfo); + my_aggregator.execute(&op_data); + return op_data.status == SUCCEEDED; } #endif diff --git a/include/oneapi/tbb/detail/_flow_graph_tagged_buffer_impl.h b/include/oneapi/tbb/detail/_flow_graph_tagged_buffer_impl.h index 0d9de17654..fe5bb6073a 100644 --- a/include/oneapi/tbb/detail/_flow_graph_tagged_buffer_impl.h +++ b/include/oneapi/tbb/detail/_flow_graph_tagged_buffer_impl.h @@ -30,32 +30,92 @@ // elements in the table are a simple list; we need pointer to next element to // traverse the chain -template -struct buffer_element_type { - // the second parameter below is void * because we can't forward-declare the type - // itself, so we just reinterpret_cast below. - typedef typename aligned_pair::type type; + +template +struct hash_buffer_element : public aligned_pair { + using key_type = Key; + using value_type = ValueType; + + value_type* get_value_ptr() { return reinterpret_cast(this->first); } + hash_buffer_element* get_next() { return reinterpret_cast(this->second); } + void set_next(hash_buffer_element* new_next) { this->second = reinterpret_cast(new_next); } + + void create_element(const value_type& v) { + ::new(this->first) value_type(v); + } + + void create_element(hash_buffer_element&& other) { + ::new(this->first) value_type(std::move(*other.get_value_ptr())); + } + + void destroy_element() { + get_value_ptr()->~value_type(); + } +}; + +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT +template +struct metainfo_hash_buffer_element : public aligned_three { + using key_type = Key; + using value_type = ValueType; + + value_type* get_value_ptr() { return reinterpret_cast(this->first); } + metainfo_hash_buffer_element* get_next() { + return reinterpret_cast(this->second); + } + void set_next(metainfo_hash_buffer_element* new_next) { this->second = reinterpret_cast(new_next); } + message_metainfo& get_metainfo() { return this->third; } + + void create_element(const value_type& v) { + ::new(this->first) value_type(v); + } + + void create_element(const value_type& v, const message_metainfo& metainfo) { + __TBB_ASSERT(this->third.empty(), nullptr); + ::new(this->first) value_type(v); + this->third = metainfo; + + for (auto waiter : metainfo.waiters()) { + waiter->reserve(); + } + } + + void create_element(metainfo_hash_buffer_element&& other) { + __TBB_ASSERT(this->third.empty(), nullptr); + ::new(this->first) value_type(std::move(*other.get_value_ptr())); + this->third = std::move(other.get_metainfo()); + } + + void destroy_element() { + get_value_ptr()->~value_type(); + + for (auto waiter : get_metainfo().waiters()) { + waiter->release(); + } + get_metainfo() = message_metainfo{}; + } }; +#endif template < - typename Key, // type of key within ValueType - typename ValueType, + typename ElementType, typename ValueToKey, // abstract method that returns "const Key" or "const Key&" given ValueType typename HashCompare, // has hash and equal - typename Allocator=tbb::cache_aligned_allocator< typename aligned_pair::type > + typename Allocator=tbb::cache_aligned_allocator > -class hash_buffer : public HashCompare { +class hash_buffer_impl : public HashCompare { public: static const size_t INITIAL_SIZE = 8; // initial size of the hash pointer table - typedef ValueType value_type; - typedef typename buffer_element_type< value_type >::type element_type; + typedef typename ElementType::key_type key_type; + typedef typename ElementType::value_type value_type; + typedef ElementType element_type; typedef value_type *pointer_type; typedef element_type *list_array_type; // array we manage manually typedef list_array_type *pointer_array_type; typedef typename std::allocator_traits::template rebind_alloc pointer_array_allocator_type; typedef typename std::allocator_traits::template rebind_alloc elements_array_allocator; - typedef typename std::decay::type Knoref; + typedef typename std::decay::type Knoref; private: ValueToKey *my_key; @@ -69,9 +129,9 @@ class hash_buffer : public HashCompare { void set_up_free_list( element_type **p_free_list, list_array_type la, size_t sz) { for(size_t i=0; i < sz - 1; ++i ) { // construct free list - la[i].second = &(la[i+1]); + la[i].set_next(&(la[i + 1])); } - la[sz-1].second = nullptr; + la[sz - 1].set_next(nullptr); *p_free_list = (element_type *)&(la[0]); } @@ -101,15 +161,18 @@ class hash_buffer : public HashCompare { { DoCleanup my_cleanup(new_pointer_array, new_elements_array, new_size); new_elements_array = elements_array_allocator().allocate(my_size); +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + for (std::size_t i = 0; i < my_size; ++i) { + ::new(new_elements_array + i) element_type(); + } +#endif new_pointer_array = pointer_array_allocator_type().allocate(new_size); for(size_t i=0; i < new_size; ++i) new_pointer_array[i] = nullptr; set_up_free_list(&new_free_list, new_elements_array, my_size ); for(size_t i=0; i < my_size; ++i) { - for( element_type* op = pointer_array[i]; op; op = (element_type *)(op->second)) { - value_type *ov = reinterpret_cast(&(op->first)); - // could have std::move semantics - internal_insert_with_key(new_pointer_array, new_size, new_free_list, *ov); + for( element_type* op = pointer_array[i]; op; op = (element_type *)(op->get_next())) { + internal_insert_with_key(new_pointer_array, new_size, new_free_list, std::move(*op)); } } my_cleanup.my_pa = nullptr; @@ -126,15 +189,26 @@ class hash_buffer : public HashCompare { // v should have perfect forwarding if std::move implemented. // we use this method to move elements in grow_array, so can't use class fields + template + const value_type& get_value_from_pack(const Value& value, const Args&...) { + return value; + } + + template + const value_type& get_value_from_pack(Element&& element) { + return *(element.get_value_ptr()); + } + + template void internal_insert_with_key( element_type **p_pointer_array, size_t p_sz, list_array_type &p_free_list, - const value_type &v) { + Args&&... args) { size_t l_mask = p_sz-1; __TBB_ASSERT(my_key, "Error: value-to-key functor not provided"); - size_t h = this->hash(tbb::detail::invoke(*my_key, v)) & l_mask; + size_t h = this->hash(tbb::detail::invoke(*my_key, get_value_from_pack(args...))) & l_mask; __TBB_ASSERT(p_free_list, "Error: free list not set up."); - element_type* my_elem = p_free_list; p_free_list = (element_type *)(p_free_list->second); - (void) new(&(my_elem->first)) value_type(v); - my_elem->second = p_pointer_array[h]; + element_type* my_elem = p_free_list; p_free_list = (element_type *)(p_free_list->get_next()); + my_elem->create_element(std::forward(args)...); + my_elem->set_next(p_pointer_array[h]); p_pointer_array[h] = my_elem; } @@ -142,6 +216,11 @@ class hash_buffer : public HashCompare { pointer_array = pointer_array_allocator_type().allocate(my_size); for(size_t i = 0; i < my_size; ++i) pointer_array[i] = nullptr; elements_array = elements_array_allocator().allocate(my_size / 2); +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + for (std::size_t i = 0; i < my_size / 2; ++i) { + ::new(elements_array + i) element_type(); + } +#endif set_up_free_list(&free_list, elements_array, my_size / 2); } @@ -151,13 +230,8 @@ class hash_buffer : public HashCompare { for(size_t i = 0; i < sz; ++i ) { element_type *p_next; for( element_type *p = pa[i]; p; p = p_next) { - p_next = (element_type *)p->second; - // TODO revamp: make sure type casting is correct. - void* ptr = (void*)(p->first); -#if _MSC_VER && _MSC_VER <= 1900 && !__INTEL_COMPILER - suppress_unused_warning(ptr); -#endif - ((value_type*)ptr)->~value_type(); + p_next = p->get_next(); + p->destroy_element(); } } pointer_array_allocator_type().deallocate(pa, sz); @@ -166,6 +240,11 @@ class hash_buffer : public HashCompare { // Separate test (if allocation of pa throws, el may be allocated. // but no elements will be constructed.) if(el) { +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + for (std::size_t i = 0; i < sz / 2; ++i) { + (el + i)->~element_type(); + } +#endif elements_array_allocator().deallocate(el, sz / 2); el = nullptr; } @@ -174,17 +253,17 @@ class hash_buffer : public HashCompare { } public: - hash_buffer() : my_key(nullptr), my_size(INITIAL_SIZE), nelements(0) { + hash_buffer_impl() : my_key(nullptr), my_size(INITIAL_SIZE), nelements(0) { internal_initialize_buffer(); } - ~hash_buffer() { + ~hash_buffer_impl() { internal_free_buffer(pointer_array, elements_array, my_size, nelements); delete my_key; my_key = nullptr; } - hash_buffer(const hash_buffer&) = delete; - hash_buffer& operator=(const hash_buffer&) = delete; + hash_buffer_impl(const hash_buffer_impl&) = delete; + hash_buffer_impl& operator=(const hash_buffer_impl&) = delete; void reset() { internal_free_buffer(pointer_array, elements_array, my_size, nelements); @@ -197,34 +276,41 @@ class hash_buffer : public HashCompare { // pointer is used to clone() ValueToKey* get_key_func() { return my_key; } - bool insert_with_key(const value_type &v) { - pointer_type p = nullptr; + template + bool insert_with_key(const value_type &v, Args&&... args) { + element_type* p = nullptr; __TBB_ASSERT(my_key, "Error: value-to-key functor not provided"); - if(find_ref_with_key(tbb::detail::invoke(*my_key, v), p)) { - p->~value_type(); - (void) new(p) value_type(v); // copy-construct into the space + if(find_element_ref_with_key(tbb::detail::invoke(*my_key, v), p)) { + p->destroy_element(); + p->create_element(v, std::forward(args)...); return false; } ++nelements; if(nelements*2 > my_size) grow_array(); - internal_insert_with_key(pointer_array, my_size, free_list, v); + internal_insert_with_key(pointer_array, my_size, free_list, v, std::forward(args)...); return true; } - // returns true and sets v to array element if found, else returns false. - bool find_ref_with_key(const Knoref& k, pointer_type &v) { + bool find_element_ref_with_key(const Knoref& k, element_type*& v) { size_t i = this->hash(k) & mask(); - for(element_type* p = pointer_array[i]; p; p = (element_type *)(p->second)) { - pointer_type pv = reinterpret_cast(&(p->first)); + for(element_type* p = pointer_array[i]; p; p = (element_type *)(p->get_next())) { __TBB_ASSERT(my_key, "Error: value-to-key functor not provided"); - if(this->equal(tbb::detail::invoke(*my_key, *pv), k)) { - v = pv; + if(this->equal(tbb::detail::invoke(*my_key, *p->get_value_ptr()), k)) { + v = p; return true; } } return false; } + // returns true and sets v to array element if found, else returns false. + bool find_ref_with_key(const Knoref& k, pointer_type &v) { + element_type* element_ptr = nullptr; + bool res = find_element_ref_with_key(k, element_ptr); + v = element_ptr->get_value_ptr(); + return res; + } + bool find_with_key( const Knoref& k, value_type &v) { value_type *p; if(find_ref_with_key(k, p)) { @@ -238,14 +324,14 @@ class hash_buffer : public HashCompare { void delete_with_key(const Knoref& k) { size_t h = this->hash(k) & mask(); element_type* prev = nullptr; - for(element_type* p = pointer_array[h]; p; prev = p, p = (element_type *)(p->second)) { - value_type *vp = reinterpret_cast(&(p->first)); + for(element_type* p = pointer_array[h]; p; prev = p, p = (element_type *)(p->get_next())) { + value_type *vp = p->get_value_ptr(); __TBB_ASSERT(my_key, "Error: value-to-key functor not provided"); if(this->equal(tbb::detail::invoke(*my_key, *vp), k)) { - vp->~value_type(); - if(prev) prev->second = p->second; - else pointer_array[h] = (element_type *)(p->second); - p->second = free_list; + p->destroy_element(); + if(prev) prev->set_next(p->get_next()); + else pointer_array[h] = (element_type *)(p->get_next()); + p->set_next(free_list); free_list = p; --nelements; return; @@ -254,4 +340,51 @@ class hash_buffer : public HashCompare { __TBB_ASSERT(false, "key not found for delete"); } }; + +template + < + typename Key, // type of key within ValueType + typename ValueType, + typename ValueToKey, // abstract method that returns "const Key" or "const Key&" given ValueType + typename HashCompare, // has hash and equal + typename Allocator=tbb::cache_aligned_allocator> + > +using hash_buffer = hash_buffer_impl, + ValueToKey, HashCompare, Allocator>; + +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT +template + < + typename Key, // type of key within ValueType + typename ValueType, + typename ValueToKey, // abstract method that returns "const Key" or "const Key&" given ValueType + typename HashCompare, // has hash and equal + typename Allocator=tbb::cache_aligned_allocator> + > +struct metainfo_hash_buffer : public hash_buffer_impl, + ValueToKey, HashCompare, Allocator> +{ +private: + using base_type = hash_buffer_impl, + ValueToKey, HashCompare, Allocator>; +public: + bool find_with_key(const typename base_type::Knoref& k, + typename base_type::value_type& v) + { + return this->find_with_key(k, v); + } + + bool find_with_key(const typename base_type::Knoref& k, + typename base_type::value_type& v, message_metainfo& metainfo) + { + typename base_type::element_type* p = nullptr; + bool result = this->find_element_ref_with_key(k, p); + if (result) { + v = *(p->get_value_ptr()); + metainfo = p->get_metainfo(); + } + return result; + } +}; +#endif // __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT #endif // __TBB__flow_graph_hash_buffer_impl_H diff --git a/include/oneapi/tbb/detail/_flow_graph_types_impl.h b/include/oneapi/tbb/detail/_flow_graph_types_impl.h index f15cdefa3e..0ee8d2f309 100644 --- a/include/oneapi/tbb/detail/_flow_graph_types_impl.h +++ b/include/oneapi/tbb/detail/_flow_graph_types_impl.h @@ -75,38 +75,45 @@ struct make_sequence < 0, S... > { //! type mimicking std::pair but with trailing fill to ensure each element of an array //* will have the correct alignment -template -struct type_plus_align { - char first[sizeof(T1)]; - T2 second; - char fill1[REM]; +template struct alignment_of { + typedef struct { char t; U padded; } test_alignment; + static const size_t value = sizeof(test_alignment) - sizeof(U); }; -template -struct type_plus_align { - char first[sizeof(T1)]; - T2 second; +template +struct max_alignment_helper; + +template +struct max_alignment_helper { + using type = typename max_alignment_helper::type>::type; }; -template struct alignment_of { - typedef struct { char t; U padded; } test_alignment; - static const size_t value = sizeof(test_alignment) - sizeof(U); +template +struct max_alignment_helper { + using type = typename std::conditional::type; }; +template +using max_alignment_helper_t = typename max_alignment_helper::type; + // T1, T2 are actual types stored. The space defined for T1 in the type returned // is a char array of the correct size. Type T2 should be trivially-constructible, // T1 must be explicitly managed. -template -struct aligned_pair { - static const size_t t1_align = alignment_of::value; - static const size_t t2_align = alignment_of::value; - typedef type_plus_align just_pair; - static const size_t max_align = t1_align < t2_align ? t2_align : t1_align; - static const size_t extra_bytes = sizeof(just_pair) % max_align; - static const size_t remainder = extra_bytes ? max_align - extra_bytes : 0; -public: - typedef type_plus_align type; -}; // aligned_pair + +template +struct alignas(alignof(max_alignment_helper_t)) aligned_pair { + char first[sizeof(T1)]; + T2 second; +}; + +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT +template +struct alignas(alignof(max_alignment_helper_t)) aligned_three { + char first[sizeof(T1)]; + T2 second; + T3 third; +}; +#endif // support for variant type // type we use when we're not storing a value diff --git a/test/tbb/test_join_node_preview.cpp b/test/tbb/test_join_node_preview.cpp index 7adb41b6e3..5799f2ac7e 100644 --- a/test/tbb/test_join_node_preview.cpp +++ b/test/tbb/test_join_node_preview.cpp @@ -229,6 +229,104 @@ void test_try_put_and_wait_reserving() { }); } +struct int_wrapper { + int i = 0; + int_wrapper() : i(0) {} + int_wrapper(int ii) : i(ii) {} + int_wrapper& operator=(int ii) { + i = ii; + return *this; + } + + int key() const { + return i; + } + + bool operator==(const int_wrapper& rhs) { + return i == rhs.i; + } +}; + +template +void test_try_put_and_wait_key_matching(Body... body) { + // Body of one argument for testing standard key_matching + // Body of zero arguments for testing message based key_matching + static_assert(sizeof...(Body) == 0 || sizeof...(Body) == 1); + tbb::task_arena arena(1); + + arena.execute([=] { + tbb::flow::graph g; + + std::vector start_work_items; + std::vector processed_items; + std::vector new_work_items; + int_wrapper wait_message = 10; + + for (int i = 0; i < wait_message.i; ++i) { + start_work_items.emplace_back(i); + if (i != 0) { + new_work_items.emplace_back(i + 10); + } + } + + using tuple_type = std::tuple; + tbb::flow::join_node> join(g, body..., body..., body...); + + tbb::flow::function_node function(g, tbb::flow::serial, + [&](tuple_type tuple) noexcept { + CHECK(std::get<0>(tuple) == std::get<1>(tuple)); + CHECK(std::get<1>(tuple) == std::get<2>(tuple)); + + auto input = std::get<0>(tuple); + + if (input == wait_message) { + for (auto item : new_work_items) { + tbb::flow::input_port<0>(join).try_put(item); + tbb::flow::input_port<1>(join).try_put(item); + tbb::flow::input_port<2>(join).try_put(item); + } + } + processed_items.emplace_back(input); + return 0; + }); + + tbb::flow::make_edge(join, function); + + tbb::flow::input_port<0>(join).try_put(wait_message); + tbb::flow::input_port<1>(join).try_put(wait_message); + + // For the first port - submit items in reversed order + for (std::size_t i = start_work_items.size(); i != 0; --i) { + tbb::flow::input_port<0>(join).try_put(start_work_items[i - 1]); + } + + // For first two ports - submit items in direct order + for (auto item : start_work_items) { + tbb::flow::input_port<1>(join).try_put(item); + tbb::flow::input_port<2>(join).try_put(item); + } + + tbb::flow::input_port<2>(join).try_put_and_wait(wait_message); + + // It is expected that the join_node would push the tuple of three copies of first element in start_work_items + // And occupy the concurrency of function. Other tuples would be rejected and taken using push-pull protocol + // in order of submission + std::size_t check_index = 0; + + for (auto item : start_work_items) { + CHECK_MESSAGE(processed_items[check_index++] == item, "Unexpected start_work_items processing"); + } + + CHECK_MESSAGE(processed_items[check_index++] == wait_message, "Unexpected wait_message processing"); + + g.wait_for_all(); + + for (auto item : new_work_items) { + CHECK_MESSAGE(processed_items[check_index++] == item, "Unexpected start_work_items processing"); + } + }); +} + //! Test follows and precedes API //! \brief \ref error_guessing TEST_CASE("Test follows and precedes API"){ @@ -252,5 +350,8 @@ TEST_CASE("Test removal of the predecessor while having none") { TEST_CASE("Test join_node try_put_and_wait") { test_try_put_and_wait_queueing(); test_try_put_and_wait_reserving(); - // TODO: add tests for key_matching, tag_matching and msg based key_matching + // Test standard key_matching policy + test_try_put_and_wait_key_matching([](int_wrapper w) { return w.i; }); + // Test msg based key_matching policy + test_try_put_and_wait_key_matching(); }