diff --git a/CMakeLists.txt b/CMakeLists.txt index 68a682056..f7b2ef7f7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -95,11 +95,11 @@ endif () # Enable C++11 support, set compilation flags #---------------------------------------------------------------------------- if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x -Wall -Wclobbered -Wempty-body -Wignored-qualifiers -Wmissing-field-initializers -Wsign-compare -Wtype-limits -Wuninitialized -Wunused-parameter -Wunused-but-set-parameter -Wno-comment") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x -Wall -Wclobbered -Wempty-body -Wignored-qualifiers -Wmissing-field-initializers -Wsign-compare -Wtype-limits -Wuninitialized -Wunused-parameter -Wunused-but-set-parameter -Wno-comment -Wno-type-limits") endif () if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x -Wall -Wempty-body -Wignored-qualifiers -Wmissing-field-initializers -Wsign-compare -Wtype-limits -Wuninitialized -Wunused-parameter -Wno-comment") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x -Wall -Wempty-body -Wignored-qualifiers -Wmissing-field-initializers -Wsign-compare -Wtype-limits -Wuninitialized -Wunused-parameter -Wno-comment -Wno-tautological-compare") option(USE_LIBCPP "Use libc++ with Clang" OFF) if (USE_LIBCPP) diff --git a/tests/sort.cpp b/tests/sort.cpp index d32574fa0..23473c6f7 100644 --- a/tests/sort.cpp +++ b/tests/sort.cpp @@ -1,11 +1,25 @@ #define BOOST_TEST_MODULE Sort #include +#include #include #include #include #include "context_setup.hpp" -BOOST_AUTO_TEST_CASE(sort_pairs) +BOOST_AUTO_TEST_CASE(sort_keys) +{ + const size_t n = 1000 * 1000; + + std::vector k = random_vector(n); + vex::vector keys(ctx, k); + + vex::sort(keys); + vex::copy(keys, k); + + BOOST_CHECK( std::is_sorted(k.begin(), k.end()) ); +} + +BOOST_AUTO_TEST_CASE(sort_keys_vals) { const size_t n = 1000 * 1000; @@ -19,13 +33,16 @@ BOOST_AUTO_TEST_CASE(sort_pairs) BOOST_CHECK( std::is_sorted(k.begin(), k.end()) ); struct { + typedef bool result_type; + VEX_FUNCTION(device, bool(int, int), "char bit1 = 1 & prm1;\n" "char bit2 = 1 & prm2;\n" "if (bit1 == bit2) return prm1 < prm2;\n" "return bit1 < bit2;\n" ); - bool operator()(int a, int b) const { + + result_type operator()(int a, int b) const { char bit1 = 1 & a; char bit2 = 1 & b; if (bit1 == bit2) return a < b; @@ -39,17 +56,78 @@ BOOST_AUTO_TEST_CASE(sort_pairs) BOOST_CHECK(std::is_sorted(k.begin(), k.end(), even_first)); } -BOOST_AUTO_TEST_CASE(sort_keys) +BOOST_AUTO_TEST_CASE(sort_keys_tuple) { const size_t n = 1000 * 1000; - std::vector k = random_vector(n); - vex::vector keys(ctx, k); + std::vector k1 = random_vector (n); + std::vector k2 = random_vector(n); - vex::sort(keys); - vex::copy(keys, k); + vex::vector keys1(ctx, k1); + vex::vector keys2(ctx, k2); - BOOST_CHECK( std::is_sorted(k.begin(), k.end()) ); + struct { + typedef bool result_type; + + VEX_FUNCTION(device, bool(int, float, int, float), + "return (prm1 == prm3) ? (prm2 < prm4) : (prm1 < prm3);" + ); + + result_type operator()(int a1, float a2, int b1, float b2) const { + return (a1 == b1) ? (a2 < b2) : (a1 < b1); + } + } less; + + vex::sort(std::tie(keys1, keys2), less ); + vex::copy(keys1, k1); + vex::copy(keys2, k2); + + BOOST_CHECK( std::is_sorted( + boost::counting_iterator(0), + boost::counting_iterator(n), + [&](size_t i, size_t j) { + return std::make_tuple(k1[i], k2[i]) < std::make_tuple(k1[j], k2[j]); + } ) ); +} + +BOOST_AUTO_TEST_CASE(sort_keys_vals_tuple) +{ + const size_t n = 1000 * 1000; + + std::vector k1 = random_vector (n); + std::vector k2 = random_vector(n); + std::vector v1 = random_vector (n); + std::vector v2 = random_vector(n); + + vex::vector keys1(ctx, k1); + vex::vector keys2(ctx, k2); + vex::vector vals1(ctx, v1); + vex::vector vals2(ctx, v2); + + struct { + typedef bool result_type; + + VEX_FUNCTION(device, bool(int, float, int, float), + "return (prm1 == prm3) ? (prm2 < prm4) : (prm1 < prm3);" + ); + + result_type operator()(int a1, float a2, int b1, float b2) const { + return (a1 == b1) ? (a2 < b2) : (a1 < b1); + } + } less; + + vex::sort_by_key(std::tie(keys1, keys2), std::tie(vals1, vals2), less ); + + vex::copy(keys1, k1); + vex::copy(keys2, k2); + + BOOST_CHECK( std::is_sorted( + boost::counting_iterator(0), + boost::counting_iterator(n), + [&](size_t i, size_t j) { + return std::make_tuple(k1[i], k2[i]) < std::make_tuple(k1[j], k2[j]); + } ) ); } + BOOST_AUTO_TEST_SUITE_END() diff --git a/vexcl/backend/cuda/device_vector.hpp b/vexcl/backend/cuda/device_vector.hpp index e6a475fcc..9e632d1ca 100644 --- a/vexcl/backend/cuda/device_vector.hpp +++ b/vexcl/backend/cuda/device_vector.hpp @@ -67,6 +67,7 @@ struct deleter_impl { template class device_vector { public: + typedef T value_type; typedef CUdeviceptr raw_type; /// Empty constructor. diff --git a/vexcl/backend/cuda/kernel.hpp b/vexcl/backend/cuda/kernel.hpp index d66478c3c..67b5e1c2c 100644 --- a/vexcl/backend/cuda/kernel.hpp +++ b/vexcl/backend/cuda/kernel.hpp @@ -79,7 +79,7 @@ class kernel { /// Adds an argument to the kernel. template - void push_arg(Arg &&arg) { + void push_arg(const Arg &arg) { char *c = (char*)&arg; prm_pos.push_back(stack.size()); stack.insert(stack.end(), c, c + sizeof(arg)); @@ -87,7 +87,7 @@ class kernel { /// Adds an argument to the kernel. template - void push_arg(device_vector &&arg) { + void push_arg(const device_vector &arg) { push_arg(arg.raw()); } diff --git a/vexcl/backend/opencl/device_vector.hpp b/vexcl/backend/opencl/device_vector.hpp index 22749dcbb..2dd64a76d 100644 --- a/vexcl/backend/opencl/device_vector.hpp +++ b/vexcl/backend/opencl/device_vector.hpp @@ -49,6 +49,7 @@ static const mem_flags MEM_READ_WRITE = CL_MEM_READ_WRITE; template class device_vector { public: + typedef T value_type; typedef cl_mem raw_type; device_vector() {} diff --git a/vexcl/sort.hpp b/vexcl/sort.hpp index 222d37537..3a8096462 100644 --- a/vexcl/sort.hpp +++ b/vexcl/sort.hpp @@ -64,6 +64,28 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + #include #include #include @@ -71,6 +93,56 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. namespace vex { namespace detail { +// Transform tuple of vex::vectors into mpl::vector of value_types +template +struct extract_value_types { + typedef typename std::decay::type T; + + template + struct loop; + + template + struct loop::type> { + typedef typename boost::mpl::push_front< + typename loop::type, + typename std::decay< + typename boost::fusion::result_of::at_c::type + >::type::value_type + >::type type; + }; + + template + struct loop::type> { + typedef boost::mpl::vector< + typename std::decay< + typename boost::fusion::result_of::at_c::type + >::type::value_type + > type; + }; + + typedef typename loop<0, boost::fusion::result_of::size::value>::type type; +}; + +struct type_iterator { + int pos; + std::function f; + + template + type_iterator(Function f) : pos(0), f(f) {} + + template + void operator()(T) { + f(pos++, type_name()); + } +}; + +template +void print_types(std::ostringstream &s) { + boost::mpl::for_each( + type_iterator([&](int, std::string t) { s << "_" << t; }) + ); +} + //--------------------------------------------------------------------------- // Memory transfer functions //--------------------------------------------------------------------------- @@ -338,53 +410,94 @@ void transfer_functions(backend::source_generator &src) { template std::string swap_function() { std::ostringstream s; - s << "swap_" << type_name(); + s << "swap"; + print_types(s); return s.str(); } +template class Address, bool Const = false> +struct pointer_param { + backend::source_generator &src; + const char *name; + int pos; + + pointer_param(backend::source_generator &src, const char *name) + : src(src), name(name), pos(0) {} + + template + void operator()(T) { + if (Const) + src.template parameter< Address >(name) << pos++; + else + src.template parameter< Address >(name) << pos++; + } +}; + template void swap_function(backend::source_generator &src) { - src.function(swap_function()) - .open("(") - .template parameter< regstr_ptr >("a") - .template parameter< regstr_ptr >("b") - .close(")").open("{"); + src.function(swap_function()).open("("); + + boost::mpl::for_each( pointer_param(src, "a") ); + boost::mpl::for_each( pointer_param(src, "b") ); + + src.close(")").open("{"); - src.new_line() << type_name() << " c = *a;"; - src.new_line() << "*a = *b;"; - src.new_line() << "*b = c;"; + boost::mpl::for_each( type_iterator([&](size_t pos, std::string tname) { + src.open("{"); + src.new_line() << tname << " c = *a" << pos << ";"; + src.new_line() << "*a" << pos << " = *b" << pos << ";"; + src.new_line() << "*b" << pos << " = c;"; + src.close("}"); + }) ); src.close("}"); } //--------------------------------------------------------------------------- -template +template std::string odd_even_transpose_sort() { std::ostringstream s; - s << "odd_even_transpose_sort_" << VT << "_" << type_name(); - if (HasValues) s << "_" << type_name(); + s << "odd_even_transpose_sort_" << VT; + print_types(s); + print_types(s); return s.str(); } -template +template void odd_even_transpose_sort(backend::source_generator &src) { swap_function(src); - if (HasValues && !std::is_same::value) swap_function(src); + if (boost::mpl::size::value && !std::is_same::value) + swap_function(src); - src.function(odd_even_transpose_sort()); + src.function(odd_even_transpose_sort()); src.open("("); - src.template parameter< regstr_ptr >("keys"); - if (HasValues) - src.template parameter< regstr_ptr >("vals"); + boost::mpl::for_each( pointer_param(src, "keys") ); + boost::mpl::for_each( pointer_param(src, "vals") ); src.close(")").open("{"); for(int I = 0; I < VT; ++I) { for(int i = 1 & I; i < VT - 1; i += 2) { - src.new_line() << "if (comp(keys[" << i + 1 << "], keys[" << i << "]))"; + src.new_line() << "if (comp("; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << (p ? ", " : "") << "keys" << p << "[" << i + 1 << "]"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", keys" << p << "[" << i << "]"; + src << "))"; src.open("{"); - src.new_line() << swap_function() << "(keys + " << i << ", keys + " << i + 1 << ");"; - if (HasValues) - src.new_line() << swap_function() << "(vals + " << i << ", vals + " << i + 1 << ");"; + src.new_line() << swap_function() << "("; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << (p ? ", " : "") << "keys" << p << " + " << i; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", keys" << p << " + " << i + 1; + src << ");"; + if (boost::mpl::size::value) { + src.new_line() << swap_function() << "("; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << (p ? ", " : "") << "vals" << p << " + " << i; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", vals" << p << " + " << i + 1; + src << ");"; + } src.close("}"); } } @@ -393,27 +506,26 @@ void odd_even_transpose_sort(backend::source_generator &src) { } //--------------------------------------------------------------------------- -template +template class Address, typename T> std::string merge_path() { - typedef typename std::decay::type>::type K; - std::ostringstream s; - s << "merge_path_" << type_name(); + s << "merge_path"; + print_types(s); return s.str(); } -template +template class Address, typename T> void merge_path(backend::source_generator &src) { - src.function(merge_path()) + src.function(merge_path()) .open("(") - .template parameter< T >("a") .template parameter< int >("a_count") - .template parameter< T >("b") .template parameter< int >("b_count") - .template parameter< int >("diag") - .close(")").open("{"); + .template parameter< int >("diag"); + + boost::mpl::for_each( pointer_param(src, "a") ); + boost::mpl::for_each( pointer_param(src, "b") ); - typedef typename std::decay::type>::type K; + src.close(")").open("{"); src.new_line() << "int begin = max(0, diag - b_count);"; src.new_line() << "int end = min(diag, a_count);"; @@ -421,9 +533,12 @@ void merge_path(backend::source_generator &src) { src.new_line() << "while (begin < end)"; src.open("{"); src.new_line() << "int mid = (begin + end) >> 1;"; - src.new_line() << type_name() << " a_key = a[mid];"; - src.new_line() << type_name() << " b_key = b[diag - 1 - mid];"; - src.new_line() << "if ( !comp(b_key, a_key) ) begin = mid + 1;"; + src.new_line() << "if ( !comp("; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << (p ? ", " : "") << "b" << p << "[diag - 1 - mid]"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", a" << p << "[mid]"; + src << ") ) begin = mid + 1;"; src.new_line() << "else end = mid;"; src.close("}"); @@ -436,7 +551,8 @@ void merge_path(backend::source_generator &src) { template std::string serial_merge() { std::ostringstream s; - s << "serial_merge_" << VT << "_" << type_name(); + s << "serial_merge_" << VT; + print_types(s); return s.str(); } @@ -444,27 +560,51 @@ template void serial_merge(backend::source_generator &src) { src.function(serial_merge()) .open("(") - .template parameter< shared_ptr >("keys_shared") .template parameter< int >("a_begin") .template parameter< int >("a_end") .template parameter< int >("b_begin") .template parameter< int >("b_end") - .template parameter< regstr_ptr >("results") - .template parameter< regstr_ptr >("indices") - .close(")").open("{"); + .template parameter< regstr_ptr >("indices"); + + boost::mpl::for_each( pointer_param(src, "keys_shared") ); + boost::mpl::for_each( pointer_param(src, "results") ); + + src.close(")").open("{"); + + boost::mpl::for_each( type_iterator([&](size_t pos, std::string tname) { + src.new_line() << tname << " a_key" << pos << " = keys_shared" << pos << "[a_begin];"; + src.new_line() << tname << " b_key" << pos << " = keys_shared" << pos << "[b_begin];"; + }) ); - src.new_line() << type_name() << " a_key = keys_shared[a_begin];"; - src.new_line() << type_name() << " b_key = keys_shared[b_begin];"; src.new_line() << "bool p;"; for(int i = 0; i < VT; ++i) { - src.new_line() << "p = (b_begin >= b_end) || ((a_begin < a_end) && !comp(b_key, a_key));"; + src.new_line() << "p = (b_begin >= b_end) || ((a_begin < a_end) && !comp("; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << (p ? ", " : "") << "b_key" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", a_key" << p; + src << "));"; + + for(int p = 0; p < boost::mpl::size::value; ++p) + src.new_line() + << "results" << p << "[" << i << "] = " + << "p ? a_key" << p << " : b_key" << p << ";"; - src.new_line() << "results[" << i << "] = p ? a_key : b_key;"; src.new_line() << "indices[" << i << "] = p ? a_begin : b_begin;"; - src.new_line() << "if(p) a_key = keys_shared[++a_begin];"; - src.new_line() << "else b_key = keys_shared[++b_begin];"; + src.new_line() << "if(p)"; + src.open("{"); + src.new_line() << "++a_begin;"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src.new_line() << "a_key" << p << " = keys_shared" << p << "[a_begin];"; + src.close("}"); + src.new_line() << "else"; + src.open("{"); + src.new_line() << "++b_begin;"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src.new_line() << "b_key" << p << " = keys_shared" << p << "[b_begin];"; + src.close("}"); } src.new_line().barrier(); @@ -476,23 +616,26 @@ void serial_merge(backend::source_generator &src) { template std::string block_sort_pass() { std::ostringstream s; - s << "block_sort_pass_" << NT << "_" << VT << "_" << type_name(); + s << "block_sort_pass_" << NT << "_" << VT; + print_types(s); return s.str(); } template void block_sort_pass(backend::source_generator &src) { - merge_path< shared_ptr >(src); + merge_path< shared_ptr, T >(src); src.function(block_sort_pass()) .open("(") - .template parameter< shared_ptr >("keys_shared") .template parameter< int >("tid") .template parameter< int >("count") .template parameter< int >("coop") - .template parameter< regstr_ptr >("keys") - .template parameter< regstr_ptr >("indices") - .close(")").open("{"); + .template parameter< regstr_ptr >("indices"); + + boost::mpl::for_each( pointer_param(src, "keys_shared") ); + boost::mpl::for_each( pointer_param(src, "keys") ); + + src.close(")").open("{"); src.new_line() << "int list = ~(coop - 1) & tid;"; src.new_line() << "int diag = min(count, " << VT << " * ((coop - 1) & tid));"; @@ -501,8 +644,20 @@ void block_sort_pass(backend::source_generator &src) { src.new_line() << "int b0 = min(count, start + " << VT << " * (coop / 2));"; src.new_line() << "int b1 = min(count, start + " << VT << " * coop);"; - src.new_line() << "int p = " << merge_path< shared_ptr >() << "(keys_shared + a0, b0 - a0, keys_shared + b0, b1 - b0, diag);"; - src.new_line() << serial_merge() << "(keys_shared, a0 + p, b0, b0 + diag - p, b1, keys, indices);"; + src.new_line() << "int p = " + << merge_path< shared_ptr, T >() + << "(b0 - a0, b1 - b0, diag"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", keys_shared" << p << " + a0"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", keys_shared" << p << " + b0"; + src << ");"; + src.new_line() << serial_merge() << "(a0 + p, b0, b0 + diag - p, b1, indices"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", keys_shared" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", keys" << p; + src << ");"; src.close("}"); } @@ -511,7 +666,8 @@ void block_sort_pass(backend::source_generator &src) { template std::string gather() { std::ostringstream s; - s << "gather_" << NT << "_" << VT << "_" << type_name(); + s << "gather_" << NT << "_" << VT; + print_types(s); return s.str(); } @@ -519,14 +675,16 @@ template void gather(backend::source_generator &src) { src.function(gather()) .open("(") - .template parameter< shared_ptr >("data") .template parameter< regstr_ptr >("indices") - .template parameter< int >("tid") - .template parameter< regstr_ptr >("reg") - .close(")").open("{"); + .template parameter< int >("tid"); + + boost::mpl::for_each( pointer_param(src, "data") ); + boost::mpl::for_each( pointer_param(src, "reg") ); + src.close(")").open("{");; for(int i = 0; i < VT; ++i) - src.new_line() << "reg[" << i << "] = data[indices[" << i << "]];"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src.new_line() << "reg" << p << "[" << i << "] = data" << p << "[indices[" << i << "]];"; src.new_line().barrier(); @@ -534,100 +692,213 @@ void gather(backend::source_generator &src) { } //--------------------------------------------------------------------------- -template +template std::string block_sort_loop() { std::ostringstream s; - s << "block_sort_loop_" << NT << "_" << VT << "_" << type_name(); - if (HasValues) s << "_" << type_name(); + s << "block_sort_loop_" << NT << "_" << VT; + print_types(s); + print_types(s); return s.str(); } -template -void block_sort_loop(backend::source_generator &src) { - block_sort_pass(src); - if (HasValues) gather(src); +template +struct call_thread_to_shared { + backend::source_generator &src; + const char *tname; + const char *sname; + int pos; - src.function(block_sort_loop()); - src.open("("); - src.template parameter< shared_ptr >("keys_shared"); - if (HasValues) { - src.template parameter< regstr_ptr >("thread_vals"); - src.template parameter< shared_ptr >("vals_shared"); + call_thread_to_shared(backend::source_generator &src, + const char *tname, const char *sname + ) : src(src), tname(tname), sname(sname), pos(0) {} + + template + void operator()(T) { + src.new_line() << thread_to_shared() + << "(" << tname << pos << ", tid, " << sname << pos << ");"; + ++pos; } - src.template parameter< int >("tid"); - src.template parameter< int >("count"); +}; + +template +void block_sort_loop(backend::source_generator &src) { + block_sort_pass(src); + if (boost::mpl::size::value) gather(src); + + src.function(block_sort_loop()).open("("); + + src.template parameter< int >("tid"); + src.template parameter< int >("count"); + + boost::mpl::for_each( pointer_param(src, "keys_shared") ); + boost::mpl::for_each( pointer_param(src, "thread_vals") ); + boost::mpl::for_each( pointer_param(src, "vals_shared") ); + src.close(")").open("{"); src.new_line() << "int indices[" << VT << "];"; - src.new_line() << type_name() << " keys[" << VT << "];"; - for(int coop = 2; coop <= NT; coop *= 2) { - src.new_line() << block_sort_pass() - << "(keys_shared, tid, count, " << coop << ", keys, indices);"; + boost::mpl::for_each( type_iterator([&](size_t pos, std::string tname) { + src.new_line() << tname << " keys" << pos++ << "[" << VT << "];"; + }) ); - if (HasValues) { + for(int coop = 2; coop <= NT; coop *= 2) { + src.new_line() << block_sort_pass() + << "(tid, count, " << coop << ", indices"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", keys_shared" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", keys" << p; + src << ");"; + + if (boost::mpl::size::value) { // Exchange the values through shared memory. - src.new_line() << thread_to_shared() - << "(thread_vals, tid, vals_shared);"; + boost::mpl::for_each( call_thread_to_shared(src, "thread_vals", "vals_shared") ); + src.new_line() << gather() - << "(vals_shared, indices, tid, thread_vals);"; + << "(indices, tid"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", vals_shared" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", thread_vals" << p; + src << ");"; } // Store results in shared memory in sorted order. - src.new_line() << thread_to_shared() - << "(keys, tid, keys_shared);"; + boost::mpl::for_each( call_thread_to_shared(src, "keys", "keys_shared") ); } src.close("}"); } //--------------------------------------------------------------------------- -template +template std::string mergesort() { std::ostringstream s; - s << "mergesort_" << NT << "_" << VT << "_" << type_name(); - if (HasValues) s << "_" << type_name(); + s << "mergesort_" << NT << "_" << VT; + print_types(s); + print_types(s); return s.str(); } -template +template void mergesort(backend::source_generator &src) { - odd_even_transpose_sort(src); - block_sort_loop(src); + odd_even_transpose_sort(src); + block_sort_loop(src); + + src.function(mergesort()).open("("); + + src.template parameter< int >("count"); + src.template parameter< int >("tid"); + + boost::mpl::for_each( pointer_param(src, "thread_keys") ); + boost::mpl::for_each( pointer_param(src, "keys_shared") ); + boost::mpl::for_each( pointer_param(src, "thread_vals") ); + boost::mpl::for_each( pointer_param(src, "vals_shared") ); - src.function(mergesort()); - src.open("("); - src.template parameter< regstr_ptr >("thread_keys"); - src.template parameter< shared_ptr >("keys_shared"); - if (HasValues) { - src.template parameter< regstr_ptr >("thread_vals"); - src.template parameter< shared_ptr >("vals_shared"); - } - src.template parameter< int >("count"); - src.template parameter< int >("tid"); src.close(")").open("{"); // Stable sort the keys in the thread. src.new_line() << "if(" << VT << " * tid < count) " - << odd_even_transpose_sort() - << "(thread_keys" - << (HasValues ? ", thread_vals);" : ");"); + << odd_even_transpose_sort() + << "("; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << (p ? ", " : "") << "thread_keys" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", thread_vals" << p; + src << ");"; // Store the locally sorted keys into shared memory. - src.new_line() << thread_to_shared() - << "(thread_keys, tid, keys_shared);"; + boost::mpl::for_each( call_thread_to_shared(src, "thread_keys", "keys_shared") ); // Recursively merge lists until the entire CTA is sorted. - src.new_line() << block_sort_loop() - << "(keys_shared, " - << (HasValues ? "thread_vals, vals_shared, " : "") - << "tid, count);"; + src.new_line() << block_sort_loop() + << "(tid, count"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", keys_shared" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", thread_vals" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", vals_shared" << p; + src << ");"; src.close("}"); } +template +struct define_transfer_functions { + backend::source_generator &src; + + define_transfer_functions(backend::source_generator &src) : src(src) {} + + template + void operator()(T) const { + transfer_functions(src); + } +}; + +template +struct call_global_to_shared { + backend::source_generator &src; + const char *gname; + const char *sname; + int pos; + + call_global_to_shared(backend::source_generator &src, + const char *gname, const char *sname + ) : src(src), gname(gname), sname(sname), pos(0) {} + + template + void operator()(T) { + src.new_line() << global_to_shared() << "(count2, " + << gname << pos << " + gid, tid, " + << sname << pos << ");"; + ++pos; + } +}; + +template +struct call_shared_to_global { + backend::source_generator &src; + const char *count; + const char *sname; + const char *gname; + const char *gid; + int pos; + + call_shared_to_global(backend::source_generator &src, + const char *count, const char *sname, const char *gname, const char *gid + ) : src(src), count(count), sname(sname), gname(gname), gid(gid), pos(0) {} + + template + void operator()(T) { + src.new_line() << shared_to_global() + << "(" << count << ", " << sname << pos << ", tid, " << gname << pos << " + " << gid << ");"; + ++pos; + } +}; + +template +struct call_shared_to_thread { + backend::source_generator &src; + const char *sname; + const char *tname; + int pos; + + call_shared_to_thread(backend::source_generator &src, + const char *sname, const char *tname + ) : src(src), sname(sname), tname(tname), pos(0) {} + + template + void operator()(T) { + src.new_line() << shared_to_thread() + << "(" << sname << pos << ", tid, " << tname << pos << ");"; + ++pos; + } +}; + //--------------------------------------------------------------------------- -template +template backend::kernel& block_sort_kernel(const backend::command_queue &queue) { static detail::kernel_cache cache; @@ -639,35 +910,54 @@ backend::kernel& block_sort_kernel(const backend::command_queue &queue) { Comp::define(src, "comp"); - transfer_functions(src); - - if (!std::is_same::value) - transfer_functions(src); - - if (HasValues && !std::is_same::value && !std::is_same::value) - transfer_functions(src); - - serial_merge(src); - mergesort(src); + boost::mpl::for_each< + typename boost::mpl::copy< + typename boost::mpl::copy< + V, + boost::mpl::back_inserter + >::type, + boost::mpl::inserter< + boost::mpl::set, + boost::mpl::insert + > + >::type + >( define_transfer_functions(src) ); + + serial_merge(src); + mergesort(src); src.kernel("block_sort"); src.open("("); - src.template parameter< int >("count"); - src.template parameter< global_ptr >("keys_src"); - src.template parameter< global_ptr< K> >("keys_dst"); - if (HasValues) { - src.template parameter< global_ptr >("vals_src"); - src.template parameter< global_ptr< V> >("vals_dst"); - } + src.template parameter< int >("count"); + + boost::mpl::for_each( pointer_param(src, "keys_src") ); + boost::mpl::for_each( pointer_param(src, "keys_dst") ); + boost::mpl::for_each( pointer_param(src, "vals_src") ); + boost::mpl::for_each( pointer_param(src, "vals_dst") ); + src.close(")").open("{"); const int NV = NT * VT; src.new_line() << "union Shared"; src.open("{"); - src.new_line() << type_name() << " keys[" << NT * (VT + 1) << "];"; - if (HasValues) - src.new_line() << type_name() << " vals[" << NV << "];"; + + src.new_line() << "struct"; + src.open("{"); + boost::mpl::for_each( type_iterator([&](size_t pos, std::string tname) { + src.new_line() << tname << " keys" << pos << "[" << NT * (VT + 1) << "];"; + }) ); + src.close("};"); + + if (boost::mpl::size::value) { + src.new_line() << "struct"; + src.open("{"); + boost::mpl::for_each( type_iterator([&](size_t pos, std::string tname) { + src.new_line() << tname << " vals" << pos << "[" << NT * VT << "];"; + }) ); + src.close("};"); + } + src.close("};"); src.smem_static_var("union Shared", "shared"); @@ -678,20 +968,20 @@ backend::kernel& block_sort_kernel(const backend::command_queue &queue) { src.new_line() << "int count2 = min(" << NV << ", count - gid);"; // Load the values into thread order. - if (HasValues) { - src.new_line() << type_name() << " thread_vals[" << VT << "];"; - src.new_line() << global_to_shared() - << "(count2, vals_src + gid, tid, shared.vals);"; - src.new_line() << shared_to_thread() - << "(shared.vals, tid, thread_vals);"; - } + boost::mpl::for_each( type_iterator([&](size_t pos, std::string tname) { + src.new_line() << tname << " thread_vals" << pos << "[" << VT << "];"; + }) ); + + boost::mpl::for_each( call_global_to_shared(src, "vals_src", "shared.vals") ); + boost::mpl::for_each( call_shared_to_thread(src, "shared.vals", "thread_vals") ); // Load keys into shared memory and transpose into register in thread order. - src.new_line() << type_name() << " thread_keys[" << VT << "];"; - src.new_line() << global_to_shared() - << "(count2, keys_src + gid, tid, shared.keys);"; - src.new_line() << shared_to_thread() - << "(shared.keys, tid, thread_keys);"; + boost::mpl::for_each( type_iterator([&](size_t pos, std::string tname) { + src.new_line() << tname << " thread_keys" << pos << "[" << VT << "];"; + }) ); + + boost::mpl::for_each( call_global_to_shared(src, "keys_src", "shared.keys") ); + boost::mpl::for_each( call_shared_to_thread(src, "shared.keys", "thread_keys") ); // If we're in the last tile, set the uninitialized keys for the thread with // a partial number of keys. @@ -699,37 +989,53 @@ backend::kernel& block_sort_kernel(const backend::command_queue &queue) { src.new_line() << "if(first + " << VT << " > count2 && first < count2)"; src.open("{"); - src.new_line() << type_name() << " max_key = thread_keys[0];"; + boost::mpl::for_each( type_iterator([&](size_t pos, std::string tname) { + src.new_line() << tname << " max_key" << pos << " = thread_keys" << pos << "[0];"; + }) ); - for(int i = 1; i < VT; ++i) + for(int i = 1; i < VT; ++i) { src.new_line() - << "if(first + " << i << " < count2)" - << " max_key = comp(max_key, thread_keys[" << i << "])" - << " ? thread_keys[" << i << "] : max_key;"; + << "if(first + " << i << " < count2 && comp("; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << (p ? ", " : "") << "max_key" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", thread_keys" << p << "[" << i << "]"; + src << ") )"; + src.open("{"); + for(int p = 0; p < boost::mpl::size::value; ++p) + src.new_line() << "max_key" << p << " = thread_keys" << p << "[" << i << "];"; + src.close("}"); + } // Fill in the uninitialized elements with max key. - for(int i = 0; i < VT; ++i) + for(int i = 0; i < VT; ++i) { src.new_line() - << "if(first + " << i << " >= count2)" - << " thread_keys[" << i << "] = max_key;"; + << "if(first + " << i << " >= count2)"; + src.open("{"); + for(int p = 0; p < boost::mpl::size::value; ++p) + src.new_line() << "thread_keys" << p << "[" << i << "] = max_key" << p << ";"; + src.close("}"); + } src.close("}"); - src.new_line() << mergesort() - << "(thread_keys, shared.keys, " - << (HasValues ? "thread_vals, shared.vals, " : "") - << "count2, tid);"; + src.new_line() << mergesort() + << "(count2, tid"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", thread_keys" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", shared.keys" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", thread_vals" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", shared.vals" << p; + src << ");"; // Store the sorted keys to global. - src.new_line() << shared_to_global() - << "(count2, shared.keys, tid, keys_dst + gid);"; - - if (HasValues) { - src.new_line() << thread_to_shared() - << "(thread_vals, tid, shared.vals);"; - src.new_line() << shared_to_global() - << "(count2, shared.vals, tid, vals_dst + gid);"; - } + boost::mpl::for_each( call_shared_to_global(src, "count2", "shared.keys", "keys_dst", "gid") ); + + boost::mpl::for_each( call_thread_to_shared(src, "thread_vals", "shared.vals") ); + boost::mpl::for_each( call_shared_to_global(src, "count2", "shared.vals", "vals_dst", "gid") ); src.close("}"); @@ -777,20 +1083,22 @@ backend::kernel merge_partition_kernel(const backend::command_queue &queue) { backend::source_generator src(queue); Comp::define(src, "comp"); - merge_path< global_ptr >(src); + merge_path< global_ptr, T >(src); find_mergesort_frame(src); src.kernel("merge_partition") .open("(") - .template parameter< global_ptr >("a_global") .template parameter< int >("a_count") - .template parameter< global_ptr >("b_global") .template parameter< int >("b_count") .template parameter< int >("nv") .template parameter< int >("coop") .template parameter< global_ptr >("mp_global") - .template parameter< int >("num_searches") - .close(")").open("{"); + .template parameter< int >("num_searches"); + + boost::mpl::for_each( pointer_param(src, "a_global") ); + boost::mpl::for_each( pointer_param(src, "b_global") ); + + src.close(")").open("{"); src.new_line() << "int partition = " << src.global_id(0) << ";"; src.new_line() << "if (partition < num_searches)"; @@ -810,8 +1118,13 @@ backend::kernel merge_partition_kernel(const backend::command_queue &queue) { src.new_line() << "gid -= a0;"; src.close("}"); - src.new_line() << "int mp = " << merge_path< global_ptr >() - << "(a_global + a0, a_count, b_global + b0, b_count, min(gid, a_count + b_count));"; + src.new_line() << "int mp = " << merge_path< global_ptr, T >() + << "(a_count, b_count, min(gid, a_count + b_count)"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", a_global" << p << " + a0"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", b_global" << p << " + b0"; + src << ");"; src.new_line() << "mp_global[partition] = mp;"; src.close("}"); @@ -825,14 +1138,97 @@ backend::kernel merge_partition_kernel(const backend::command_queue &queue) { return kernel->second; } +template +struct temp_storage { + device_vector< typename boost::mpl::at_c::type > head; + temp_storage tail; + + temp_storage(const backend::command_queue &queue, size_t n) + : head(queue, n), tail(queue, n) {} + + template + typename std::enable_if< + (J > I), + device_vector< typename boost::mpl::at_c::type > + >::type& + get() { + return tail.template get(); + } + + template + typename std::enable_if< + (J == I), + device_vector< typename boost::mpl::at_c::type > + >::type& + get() { + return head; + } + + template + typename std::enable_if< + boost::fusion::result_of::size::value == boost::mpl::size::value && + I + 1 < static_cast(boost::mpl::size::value), + void + >::type + swap(Tuple &t) { + std::swap(head, boost::fusion::at_c(t)); + tail.swap(t); + } + + template + typename std::enable_if< + boost::fusion::result_of::size::value == boost::mpl::size::value && + I + 1 == boost::mpl::size::value, + void + >::type + swap(Tuple &t) { + std::swap(head, boost::fusion::at_c(t)); + } +}; + + +template +struct temp_storage::value>::type > +{ + temp_storage(const backend::command_queue&, size_t) {} +}; + +template +struct arg_pusher { + template + static void push(backend::kernel &krn, T &&t) { + krn.push_arg(boost::fusion::at_c(t)); + arg_pusher::push(krn, std::forward(t)); + } + + template + static void push(backend::kernel &krn, temp_storage &t) { + krn.push_arg(t.template get()); + arg_pusher::push(krn, t); + } +}; + +template +struct arg_pusher::type> { + template static void push(backend::kernel&, T&&) { } +}; + +template +void push_args(backend::kernel &krn, T &&t) { + arg_pusher<0, N>::push(krn, t); +} + //--------------------------------------------------------------------------- -template +template backend::device_vector merge_path_partitions( const backend::command_queue &queue, - const device_vector &keys, + const KT &keys, int count, int nv, int coop ) { + typedef typename extract_value_types::type K; + const int NT = 64; int num_partitions = (count + nv - 1) / nv; @@ -840,20 +1236,21 @@ backend::device_vector merge_path_partitions( backend::device_vector partitions(queue, num_partitions + 1); - auto merge_partition = merge_partition_kernel(queue); + auto merge_partition = merge_partition_kernel(queue); - int a_count = keys.size(); + int a_count = boost::fusion::at_c<0>(keys).size(); int b_count = 0; - merge_partition.push_arg(keys); merge_partition.push_arg(a_count); - merge_partition.push_arg(keys); merge_partition.push_arg(b_count); merge_partition.push_arg(nv); merge_partition.push_arg(coop); merge_partition.push_arg(partitions); merge_partition.push_arg(num_partitions + 1); + push_args::value>(merge_partition, keys); + push_args::value>(merge_partition, keys); + merge_partition.config(num_partition_blocks, NT); merge_partition(queue); @@ -1028,28 +1425,70 @@ void load2_to_shared(backend::source_generator &src) { template std::string merge_keys_indices() { std::ostringstream s; - s << "merge_keys_indices_" << NT << "_" << VT << "_" << type_name(); + s << "merge_keys_indices_" << NT << "_" << VT; + print_types(s); return s.str(); } + +template +struct define_load2_to_shared { + backend::source_generator &src; + + define_load2_to_shared(backend::source_generator &src) : src(src) {} + + template + void operator()(T) const { + load2_to_shared(src); + } +}; + +template +struct call_load2_to_shared { + backend::source_generator &src; + int pos; + + call_load2_to_shared(backend::source_generator &src) : src(src), pos(0) {} + + template + void operator()(T) { + src.new_line() << load2_to_shared() + << "(a_global" << pos << " + a0, a_count, b_global" << pos + << " + b0, b_count, tid, keys_shared" << pos << ");"; + + ++pos; + } +}; + template void merge_keys_indices(backend::source_generator &src) { - serial_merge(src); - merge_path< shared_ptr >(src); - load2_to_shared(src); + serial_merge(src); + merge_path< shared_ptr, T >(src); + + boost::mpl::for_each< + typename boost::mpl::copy< + T, + boost::mpl::inserter< + boost::mpl::set0<>, + boost::mpl::insert + > + >::type + >( define_load2_to_shared(src) ); src.function(merge_keys_indices()) .open("(") - .template parameter< global_ptr >("a_global") .template parameter< int >("a_count") - .template parameter< global_ptr >("b_global") .template parameter< int >("b_count") .template parameter< cl_int4 >("range") .template parameter< int >("tid") - .template parameter< shared_ptr >("keys_shared") - .template parameter< regstr_ptr >("results") - .template parameter< regstr_ptr >("indices") - .close(")").open("{"); + .template parameter< regstr_ptr >("indices"); + + boost::mpl::for_each( pointer_param(src, "a_global") ); + boost::mpl::for_each( pointer_param(src, "b_global") ); + boost::mpl::for_each( pointer_param(src, "keys_shared") ); + boost::mpl::for_each( pointer_param(src, "results") ); + + src.close(")").open("{");; src.new_line() << "int a0 = range.x;"; src.new_line() << "int a1 = range.y;"; @@ -1062,14 +1501,18 @@ void merge_keys_indices(backend::source_generator &src) { src.new_line() << "b_count = b1 - b0;"; // Load the data into shared memory. - src.new_line() << load2_to_shared() - << "(a_global + a0, a_count, b_global + b0, b_count, tid, keys_shared);"; + boost::mpl::for_each( call_load2_to_shared(src) ); // Run a merge path to find the start of the serial merge for each // thread. src.new_line() << "int diag = " << VT << " * tid;"; - src.new_line() << "int mp = " << merge_path< shared_ptr >() - << "(keys_shared, a_count, keys_shared + a_count, b_count, diag);"; + src.new_line() << "int mp = " << merge_path< shared_ptr, T >() + << "(a_count, b_count, diag"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", keys_shared" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", keys_shared" << p << " + a_count"; + src << ");"; // Compute the ranges of the sources in shared memory. src.new_line() << "int a0tid = mp;"; @@ -1078,8 +1521,13 @@ void merge_keys_indices(backend::source_generator &src) { src.new_line() << "int b1tid = a_count + b_count;"; // Serial merge into register. - src.new_line() << serial_merge() - << "(keys_shared, a0tid, a1tid, b0tid, b1tid, results, indices);"; + src.new_line() << serial_merge() + << "(a0tid, a1tid, b0tid, b1tid, indices"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", keys_shared" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", results" << p; + src << ");"; src.close("}"); } @@ -1088,7 +1536,8 @@ void merge_keys_indices(backend::source_generator &src) { template std::string transfer_merge_values_regstr() { std::ostringstream s; - s << "transfer_merge_values_regstr_" << NT << "_" << VT << "_" << type_name(); + s << "transfer_merge_values_regstr_" << NT << "_" << VT; + print_types(s); return s.str(); } @@ -1097,21 +1546,33 @@ void transfer_merge_values_regstr(backend::source_generator &src) { src.function(transfer_merge_values_regstr()) .open("(") .template parameter< int >("count") - .template parameter< global_ptr >("a_global") - .template parameter< global_ptr >("b_global") .template parameter< int >("b_start") .template parameter< regstr_ptr >("indices") - .template parameter< int >("tid") - .template parameter< regstr_ptr >("reg") - .close(")").open("{"); + .template parameter< int >("tid"); + + boost::mpl::for_each( pointer_param(src, "a_global") ); + boost::mpl::for_each( pointer_param(src, "b_global") ); + boost::mpl::for_each( pointer_param(src, "reg") ); + + src.close(")").open("{"); - src.new_line() << "b_global -= b_start;"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src.new_line() << "b_global" << p << " -= b_start;"; src.new_line() << "if(count >= " << NT * VT << ")"; src.open("{"); - for(int i = 0; i < VT; ++i) - src.new_line() << "reg[" << i << "] = (indices[" << i << "] < b_start) " - "? a_global[indices[" << i << "]] : b_global[indices[" << i << "]];"; + for(int i = 0; i < VT; ++i) { + src.new_line() << "if (indices[" << i << "] < b_start)"; + src.open("{"); + for(int p = 0; p < boost::mpl::size::value; ++p) + src.new_line() << "reg" << p << "[" << i << "] = a_global" << p << "[indices[" << i << "]];"; + src.close("}"); + src.new_line() << "else"; + src.open("{"); + for(int p = 0; p < boost::mpl::size::value; ++p) + src.new_line() << "reg" << p << "[" << i << "] = b_global" << p << "[indices[" << i << "]];"; + src.close("}"); + } src.close("}"); src.new_line() << "else"; @@ -1121,10 +1582,19 @@ void transfer_merge_values_regstr(backend::source_generator &src) { for(int i = 0; i < VT; ++i) { src.new_line() << "index = " << NT * i << " + tid;"; - src.new_line() << - "if(index < count) " - "reg[" << i << "] = (indices[" << i << "] < b_start) ? " - "a_global[indices[" << i << "]] : b_global[indices[" << i << "]];"; + src.new_line() << "if(index < count)"; + src.open("{"); + src.new_line() << "if (indices[" << i << "] < b_start)"; + src.open("{"); + for(int p = 0; p < boost::mpl::size::value; ++p) + src.new_line() << "reg" << p << "[" << i << "] = a_global" << p << "[indices[" << i << "]];"; + src.close("}"); + src.new_line() << "else"; + src.open("{"); + for(int p = 0; p < boost::mpl::size::value; ++p) + src.new_line() << "reg" << p << "[" << i << "] = b_global" << p << "[indices[" << i << "]];"; + src.close("}"); + src.close("}"); } src.close("}"); @@ -1137,105 +1607,157 @@ void transfer_merge_values_regstr(backend::source_generator &src) { template std::string transfer_merge_values_shared() { std::ostringstream s; - s << "transfer_merge_values_shared_" << NT << "_" << VT << "_" << type_name(); + s << "transfer_merge_values_shared_" << NT << "_" << VT; + print_types(s); return s.str(); } +template +struct call_regstr_to_global { + backend::source_generator &src; + int pos; + + call_regstr_to_global(backend::source_generator &src) : src(src), pos(0) {} + + template + void operator()(T) { + src.new_line() << regstr_to_global() + << "(count, reg" << pos << ", tid, dest_global" << pos << ");"; + ++pos; + } +}; + template void transfer_merge_values_shared(backend::source_generator &src) { - transfer_merge_values_regstr(src); + transfer_merge_values_regstr(src); src.function(transfer_merge_values_shared()) .open("(") .template parameter< int >("count") - .template parameter< global_ptr >("a_global") - .template parameter< global_ptr >("b_global") .template parameter< int >("b_start") .template parameter< shared_ptr >("indices_shared") - .template parameter< int >("tid") - .template parameter< global_ptr >("dest_global") - .close(")").open("{"); + .template parameter< int >("tid"); + + boost::mpl::for_each( pointer_param(src, "a_global") ); + boost::mpl::for_each( pointer_param(src, "b_global") ); + boost::mpl::for_each( pointer_param(src, "dest_global") ); + + src.close(")").open("{"); src.new_line() << "int indices[" << VT << "];"; src.new_line() << shared_to_regstr() << "(indices_shared, tid, indices);"; - src.new_line() << type_name() << " reg[" << VT << "];"; - src.new_line() << transfer_merge_values_regstr() - << "(count, a_global, b_global, b_start, indices, tid, reg);"; - src.new_line() << regstr_to_global() - << "(count, reg, tid, dest_global);"; + boost::mpl::for_each( type_iterator([&](size_t pos, std::string tname) { + src.new_line() << tname << " reg" << pos++ << "[" << VT << "];"; + }) ); + + src.new_line() << transfer_merge_values_regstr() + << "(count, b_start, indices, tid"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", a_global" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", b_global" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", reg" << p; + src << ");"; + + boost::mpl::for_each( call_regstr_to_global(src) ); src.close("}"); } //--------------------------------------------------------------------------- -template +template std::string device_merge() { std::ostringstream s; - s << "device_merge_" << NT << "_" << VT << "_" << type_name(); - if (HasValues) s << "_" << type_name(); + s << "device_merge_" << NT << "_" << VT; + print_types(s); + print_types(s); return s.str(); } -template +template void device_merge(backend::source_generator &src) { - merge_keys_indices(src); - transfer_merge_values_shared(src); + merge_keys_indices(src); + if (boost::mpl::size::value) + transfer_merge_values_shared(src); - src.function(device_merge()); + src.function(device_merge()); src.open("("); src.template parameter< int >("a_count"); src.template parameter< int >("b_count"); - src.template parameter< global_ptr >("a_keys_global"); - src.template parameter< global_ptr >("b_keys_global"); - src.template parameter< global_ptr >("keys_global"); - src.template parameter< shared_ptr >("keys_shared"); - if (HasValues) { - src.template parameter< global_ptr >("a_vals_global"); - src.template parameter< global_ptr >("b_vals_global"); - src.template parameter< global_ptr >("vals_global"); - } + + boost::mpl::for_each( pointer_param(src, "a_keys_global")); + boost::mpl::for_each( pointer_param(src, "b_keys_global")); + boost::mpl::for_each( pointer_param(src, "keys_global")); + boost::mpl::for_each( pointer_param(src, "keys_shared")); + + boost::mpl::for_each( pointer_param(src, "a_vals_global")); + boost::mpl::for_each( pointer_param(src, "b_vals_global")); + boost::mpl::for_each( pointer_param(src, "vals_global")); + src.template parameter< int >("tid"); src.template parameter< int >("block"); src.template parameter< cl_int4 >("range"); src.template parameter< shared_ptr >("indices_shared"); src.close(")").open("{"); - src.new_line() << type_name() << " results[" << VT << "];"; + boost::mpl::for_each( type_iterator([&](size_t pos, std::string tname) { + src.new_line() << tname << " results" << pos++ << "[" << VT << "];"; + }) ); + src.new_line() << "int indices[" << VT << "];"; - src.new_line() << merge_keys_indices() - << "(a_keys_global, a_count, b_keys_global, b_count, range, tid, " - << "keys_shared, results, indices);"; + src.new_line() << merge_keys_indices() + << "(a_count, b_count, range, tid, indices"; - // Store merge results back to shared memory. - src.new_line() << thread_to_shared() - << "(results, tid, keys_shared);"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", a_keys_global" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", b_keys_global" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", keys_shared" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", results" << p; + src << ");"; + + // Store merge results back to shared memory. + boost::mpl::for_each( call_thread_to_shared(src, "results", "keys_shared") ); // Store merged keys to global memory. src.new_line() << "a_count = range.y - range.x;"; src.new_line() << "b_count = range.w - range.z;"; - src.new_line() << shared_to_global() - << "(a_count + b_count, keys_shared, tid, keys_global + " - << NT * VT << " * block);"; + + { + std::ostringstream s; + s << NT * VT << " * block"; + boost::mpl::for_each( call_shared_to_global(src, "a_count + b_count", "keys_shared", "keys_global", s.str().c_str() ) ); + } // Copy the values. - if (HasValues) { + if (boost::mpl::size::value) { src.new_line() << thread_to_shared() << "(indices, tid, indices_shared);"; - src.new_line() << transfer_merge_values_shared() - << "(a_count + b_count, a_vals_global + range.x, " - "b_vals_global + range.z, a_count, indices_shared, tid, " - "vals_global + " << NT * VT << " * block);"; + + src.new_line() << transfer_merge_values_shared() + << "(a_count + b_count, a_count, indices_shared, tid"; + + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", a_vals_global" << p << " + range.x"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", b_vals_global" << p << " + range.z"; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", vals_global" << p << " + " << NT * VT << " * block"; + src << ");"; } src.close("}"); } //--------------------------------------------------------------------------- -template +template backend::kernel merge_kernel(const backend::command_queue &queue) { static detail::kernel_cache cache; @@ -1248,28 +1770,34 @@ backend::kernel merge_kernel(const backend::command_queue &queue) { Comp::define(src, "comp"); compute_merge_range(src); - transfer_functions(src); + boost::mpl::for_each< + typename boost::mpl::copy< + typename boost::mpl::copy< + V, + boost::mpl::back_inserter + >::type, + boost::mpl::inserter< + boost::mpl::set, + boost::mpl::insert + > + >::type + >( define_transfer_functions(src) ); + + device_merge(src); - if (!std::is_same::value) - transfer_functions(src); + src.kernel("merge"); + src.open("("); + src.template parameter< int >("a_count"); + src.template parameter< int >("b_count"); - if (HasValues && !std::is_same::value && !std::is_same::value) - transfer_functions(src); + boost::mpl::for_each( pointer_param(src, "a_keys_global")); + boost::mpl::for_each( pointer_param(src, "b_keys_global")); + boost::mpl::for_each( pointer_param(src, "keys_global")); - device_merge(src); + boost::mpl::for_each( pointer_param(src, "a_vals_global")); + boost::mpl::for_each( pointer_param(src, "b_vals_global")); + boost::mpl::for_each( pointer_param(src, "vals_global")); - src.kernel("merge"); - src.open("("); - src.template parameter< int >("a_count"); - src.template parameter< int >("b_count"); - src.template parameter< global_ptr >("a_keys_global"); - src.template parameter< global_ptr >("b_keys_global"); - src.template parameter< global_ptr >("keys_global"); - if (HasValues) { - src.template parameter< global_ptr >("a_vals_global"); - src.template parameter< global_ptr >("b_vals_global"); - src.template parameter< global_ptr >("vals_global"); - } src.template parameter< global_ptr >("mp_global"); src.template parameter< int >("coop"); src.close(")").open("{"); @@ -1278,7 +1806,14 @@ backend::kernel merge_kernel(const backend::command_queue &queue) { src.new_line() << "union Shared"; src.open("{"); - src.new_line() << type_name() << " keys[" << NT * (VT + 1) << "];"; + + src.new_line() << "struct"; + src.open("{"); + boost::mpl::for_each( type_iterator([&](size_t pos, std::string tname) { + src.new_line() << tname << " keys" << pos++ << "[" << NT * (VT + 1) << "];"; + }) ); + src.close("};"); + src.new_line() << "int indices[" << NV << "];"; src.close("};"); @@ -1290,10 +1825,26 @@ backend::kernel merge_kernel(const backend::command_queue &queue) { src.new_line() << "int4 range = compute_merge_range(" "a_count, b_count, block, coop, " << NV << ", mp_global);"; - src.new_line() << device_merge() - << "(a_count, b_count, a_keys_global, b_keys_global, keys_global, shared.keys, " - << (HasValues ? "a_vals_global, b_vals_global, vals_global, " : "") - << "tid, block, range, shared.indices);"; + src.new_line() << device_merge() + << "(a_count, b_count"; + + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", a_keys_global" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", b_keys_global" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", keys_global" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", shared.keys" << p; + + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", a_vals_global" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", b_vals_global" << p; + for(int p = 0; p < boost::mpl::size::value; ++p) + src << ", vals_global" << p; + + src << ", tid, block, range, shared.indices);"; src.close("}"); @@ -1323,45 +1874,61 @@ inline int find_log2(int x, bool round_up = false) { return a; } +template +device_vector< typename boost::mpl::at_c::type >& +at_c(temp_storage &s) { + return s.template get(); +} + /// Sorts single partition of a vector. -template -void sort(const backend::command_queue &queue, - backend::device_vector &keys, Comp) -{ +template +void sort(const backend::command_queue &queue, KT &keys, Comp) { + typedef typename extract_value_types::type K; + using boost::fusion::at_c; + + typedef + typename boost::mpl::accumulate< + K, + boost::mpl::int_<0>, + boost::mpl::plus > + >::type + sizeof_keys; + backend::select_context(queue); const int NT_cpu = 1; const int NT_gpu = 256; const int NT = is_cpu(queue) ? NT_cpu : NT_gpu; - const int VT = (sizeof(K) > 4) ? 7 : 11; + const int VT = (sizeof_keys::value > 4) ? 7 : 11; const int NV = NT * VT; - const int count = keys.size(); + const int count = at_c<0>(keys).size(); const int num_blocks = (count + NV - 1) / NV; const int num_passes = detail::find_log2(num_blocks, true); - - device_vector keys_dst(queue, count); + temp_storage tmp(queue, count); auto block_sort = is_cpu(queue) ? - detail::block_sort_kernel(queue) : - detail::block_sort_kernel(queue); + detail::block_sort_kernel, Comp>(queue) : + detail::block_sort_kernel, Comp>(queue); block_sort.push_arg(count); - block_sort.push_arg(keys); - block_sort.push_arg(1 & num_passes ? keys_dst : keys); + + push_args::value>(block_sort, keys); + if (1 & num_passes) + push_args::value>(block_sort, tmp); + else + push_args::value>(block_sort, keys); block_sort.config(num_blocks, NT); block_sort(queue); - if (1 & num_passes) { - std::swap(keys, keys_dst); - } + if (1 & num_passes) tmp.swap(keys); auto merge = is_cpu(queue) ? - detail::merge_kernel(queue) : - detail::merge_kernel(queue); + detail::merge_kernel, Comp>(queue) : + detail::merge_kernel, Comp>(queue); for(int pass = 0; pass < num_passes; ++pass) { int coop = 2 << pass; @@ -1370,28 +1937,29 @@ void sort(const backend::command_queue &queue, merge.push_arg(count); merge.push_arg(0); - merge.push_arg(keys); - merge.push_arg(keys); - merge.push_arg(keys_dst); + + push_args::value>(merge, keys); + push_args::value>(merge, keys); + push_args::value>(merge, tmp); merge.push_arg(partitions); merge.push_arg(coop); merge.config(num_blocks, NT); merge(queue); - std::swap(keys, keys_dst); + tmp.swap(keys); } } /// Sorts single partition of a vector. -template -void sort_by_key(const backend::command_queue &queue, - backend::device_vector &keys, - backend::device_vector &vals, - Comp - ) -{ - precondition(keys.size() == vals.size(), +template +void sort_by_key(const backend::command_queue &queue, KTup &&keys, VTup &&vals, Comp) { + typedef typename extract_value_types::type K; + typedef typename extract_value_types::type V; + + using boost::fusion::at_c; + + precondition(at_c<0>(keys).size() == at_c<0>(vals).size(), "keys and values should have same size" ); @@ -1403,35 +1971,43 @@ void sort_by_key(const backend::command_queue &queue, const int VT = (sizeof(K) > 4) ? 7 : 11; const int NV = NT * VT; - const int count = keys.size(); + const int count = at_c<0>(keys).size(); const int num_blocks = (count + NV - 1) / NV; const int num_passes = detail::find_log2(num_blocks, true); - device_vector keys_dst(queue, count); - device_vector vals_dst(queue, count); + temp_storage keys_tmp(queue, count); + temp_storage vals_tmp(queue, count); auto block_sort = is_cpu(queue) ? - detail::block_sort_kernel(queue) : - detail::block_sort_kernel(queue); + detail::block_sort_kernel(queue) : + detail::block_sort_kernel(queue); block_sort.push_arg(count); - block_sort.push_arg(keys); - block_sort.push_arg(1 & num_passes ? keys_dst : keys); - block_sort.push_arg(vals); - block_sort.push_arg(1 & num_passes ? vals_dst : vals); + + push_args::value>(block_sort, keys); + if (1 & num_passes) + push_args::value>(block_sort, keys_tmp); + else + push_args::value>(block_sort, keys); + + push_args::value>(block_sort, vals); + if (1 & num_passes) + push_args::value>(block_sort, vals_tmp); + else + push_args::value>(block_sort, vals); block_sort.config(num_blocks, NT); block_sort(queue); if (1 & num_passes) { - std::swap(keys, keys_dst); - std::swap(vals, vals_dst); + keys_tmp.swap(keys); + vals_tmp.swap(vals); } auto merge = is_cpu(queue) ? - detail::merge_kernel(queue) : - detail::merge_kernel(queue); + detail::merge_kernel(queue) : + detail::merge_kernel(queue); for(int pass = 0; pass < num_passes; ++pass) { int coop = 2 << pass; @@ -1440,54 +2016,130 @@ void sort_by_key(const backend::command_queue &queue, merge.push_arg(count); merge.push_arg(0); - merge.push_arg(keys); - merge.push_arg(keys); - merge.push_arg(keys_dst); - merge.push_arg(vals); - merge.push_arg(vals); - merge.push_arg(vals_dst); + + push_args::value>(merge, keys); + push_args::value>(merge, keys); + push_args::value>(merge, keys_tmp); + + push_args::value>(merge, vals); + push_args::value>(merge, vals); + push_args::value>(merge, vals_tmp); + merge.push_arg(partitions); merge.push_arg(coop); merge.config(num_blocks, NT); merge(queue); - std::swap(keys, keys_dst); - std::swap(vals, vals_dst); + keys_tmp.swap(keys); + vals_tmp.swap(vals); } } +template +boost::fusion::zip_view< boost::fusion::vector > +make_zip_view(S1 &s1, S2 &s2) { + typedef boost::fusion::vector Z; + return boost::fusion::zip_view( Z(s1, s2)); +} + +struct do_resize { + size_t n; + do_resize(size_t n) : n(n) {} + template + void operator()(V &v) const { + v.resize(n); + } +}; + +struct do_copy { + template + void operator()(T t) const { + using boost::fusion::at_c; + vex::copy(at_c<0>(t), at_c<1>(t)); + } +}; + +struct copy_element { + size_t dst, src; + copy_element(size_t dst, size_t src) : dst(dst), src(src) {} + + template + void operator()(T t) const { + using boost::fusion::at_c; + at_c<0>(t)[dst] = at_c<1>(t)[src]; + } +}; + +struct do_index { + size_t pos; + do_index(size_t pos) : pos(pos) {} + + template struct result; + + template + struct result< This(T) > { + typedef typename std::decay::type::value_type type; + }; + + template + typename result::type operator()(const T &t) const { + return t[pos]; + } +}; + /// Merges partially sorted vector partitions into host vector -template -std::vector merge(const vector &keys, Comp comp) { - const auto &queue = keys.queue_list(); +template +typename boost::fusion::result_of::as_vector< + typename boost::mpl::transform< + typename extract_value_types::type, + std::vector + >::type +>::type +merge(const KTuple &keys, Comp comp) { + namespace fusion = boost::fusion; - std::vector dst(keys.size()); - vex::copy(keys, dst); + typedef typename extract_value_types::type K; + + const auto &queue = fusion::at_c<0>(keys).queue_list(); + const size_t count = fusion::at_c<0>(keys).size(); + + typedef typename fusion::result_of::as_vector< + typename boost::mpl::transform< K, std::vector >::type + >::type host_vectors; + + host_vectors dst; + + fusion::for_each(dst, do_resize(count) ); + fusion::for_each( make_zip_view(keys, dst), do_copy() ); if (queue.size() > 1) { - std::vector src(keys.size()); - std::swap(src, dst); + host_vectors src; + fusion::for_each(src, do_resize(count) ); + + fusion::swap(src, dst); - typedef typename std::vector::const_iterator iterator; - std::vector begin(queue.size()); - std::vector end(queue.size()); + std::vector begin(queue.size()); + std::vector end (queue.size()); for(unsigned d = 0; d < queue.size(); ++d) { - begin[d] = src.begin() + keys.part_start(d); - end [d] = src.begin() + keys.part_start(d + 1); + begin[d] = fusion::at_c<0>(keys).part_start(d); + end [d] = fusion::at_c<0>(keys).part_start(d + 1); } - for(auto pos = dst.begin(); pos != dst.end(); ++pos) { + for(size_t pos = 0; pos < count; ++pos) { int winner = -1; for(unsigned d = 0; d < queue.size(); ++d) { if (begin[d] == end[d]) continue; - if (winner < 0 || comp(*begin[d], *begin[winner])) + auto curr = fusion::transform(src, do_index(begin[d])); + auto best = fusion::transform(src, do_index(begin[winner])); + + if (winner < 0 || fusion::invoke(comp, fusion::join(curr, best))) winner = d; } - *pos = *begin[winner]++; + fusion::for_each(make_zip_view(dst, src), copy_element(pos, begin[winner]++)); } } @@ -1495,55 +2147,165 @@ std::vector merge(const vector &keys, Comp comp) { } /// Merges partially sorted vector partitions into host vector -template -void merge(const vector &keys, const vector &vals, Comp comp, - std::vector &dst_keys, std::vector &dst_vals) -{ - const auto &queue = keys.queue_list(); - - dst_keys.resize(keys.size()); - dst_vals.resize(keys.size()); - - vex::copy(keys, dst_keys); - vex::copy(vals, dst_vals); +template +typename boost::fusion::result_of::as_vector< + typename boost::mpl::transform< + typename extract_value_types< + typename boost::fusion::result_of::as_vector< + typename boost::fusion::result_of::join::type + >::type + >::type, + std::vector + >::type +>::type +merge(const KTuple &keys, const VTuple &vals, Comp comp) { + namespace fusion = boost::fusion; + + typedef typename extract_value_types::type K; + typedef typename extract_value_types::type V; + + const auto &queue = fusion::at_c<0>(keys).queue_list(); + const size_t count = fusion::at_c<0>(keys).size(); + + typedef typename fusion::result_of::as_vector< + typename boost::mpl::transform< K, std::vector >::type + >::type host_keys; + + typedef typename fusion::result_of::as_vector< + typename boost::mpl::transform< V, std::vector >::type + >::type host_vals; + + host_keys dst_keys; + host_vals dst_vals; + + fusion::for_each(dst_keys, do_resize(count) ); + fusion::for_each(dst_vals, do_resize(count) ); + + fusion::for_each( make_zip_view(keys, dst_keys), do_copy() ); + fusion::for_each( make_zip_view(vals, dst_vals), do_copy() ); if (queue.size() > 1) { - std::vector src_keys(keys.size()); - std::vector src_vals(keys.size()); + host_keys src_keys; + host_vals src_vals; - std::swap(src_keys, dst_keys); - std::swap(src_vals, dst_vals); + fusion::for_each(src_keys, do_resize(count) ); + fusion::for_each(src_vals, do_resize(count) ); - typedef typename std::vector::const_iterator key_iterator; - typedef typename std::vector::const_iterator val_iterator; + fusion::swap(src_keys, dst_keys); + fusion::swap(src_vals, dst_vals); - std::vector key_begin(queue.size()), key_end(queue.size()); - std::vector val_begin(queue.size()), val_end(queue.size()); + std::vector begin(queue.size()), end(queue.size()); for(unsigned d = 0; d < queue.size(); ++d) { - key_begin[d] = src_keys.begin() + keys.part_start(d); - key_end [d] = src_keys.begin() + keys.part_start(d + 1); - - val_begin[d] = src_vals.begin() + keys.part_start(d); - val_end [d] = src_vals.begin() + keys.part_start(d + 1); + begin[d] = fusion::at_c<0>(keys).part_start(d); + end [d] = fusion::at_c<0>(keys).part_start(d + 1); } - auto key_pos = dst_keys.begin(); - auto val_pos = dst_vals.begin(); - - while(key_pos != dst_keys.end()) { + for(size_t pos = 0; pos < count; ++pos) { int winner = -1; for(unsigned d = 0; d < queue.size(); ++d) { - if (key_begin[d] == key_end[d]) continue; + if (begin[d] == end[d]) continue; + + auto curr = fusion::transform(src_keys, do_index(begin[d])); + auto best = fusion::transform(src_keys, do_index(begin[winner])); - if (winner < 0 || comp(*key_begin[d], *key_begin[winner])) + if (winner < 0 || fusion::invoke(comp, fusion::join(curr, best))) winner = d; } - *key_pos++ = *key_begin[winner]++; - *val_pos++ = *val_begin[winner]++; + fusion::for_each(make_zip_view(dst_keys, src_keys), copy_element(pos, begin[winner])); + fusion::for_each(make_zip_view(dst_keys, src_keys), copy_element(pos, begin[winner])); + + ++begin[winner]; } } + + return fusion::as_vector(fusion::join(dst_keys, dst_vals)); +} + +template +typename std::enable_if< + boost::fusion::traits::is_sequence::value, + T& +>::type +forward_as_sequence(T &t) { + return t; +} + +template +typename std::enable_if< + !boost::fusion::traits::is_sequence::value, + boost::fusion::vector +>::type +forward_as_sequence(T &t) { + return boost::fusion::vector(t); +} + +struct extract_device_vector { + uint d; + + extract_device_vector(uint d) : d(d) {} + + template struct result; + + template + struct result< This(T) > { + typedef device_vector::type::value_type>& type; + }; + + template + typename result::type operator()(T &t) const { + return t(d); + } +}; + +template +void sort_sink(K &&keys, Comp comp) { + namespace fusion = boost::fusion; + + const auto &queue = boost::fusion::at_c<0>(keys).queue_list(); + + for(unsigned d = 0; d < queue.size(); ++d) + if (fusion::at_c<0>(keys).part_size(d)) { + auto part = fusion::transform(keys, extract_device_vector(d)); + sort(queue[d], part, comp.device); + } + + if (queue.size() <= 1) return; + + // Vector partitions have been sorted on compute devices. + // Now we need to merge them on a CPU. This is a linear time operation, + // so total performance should be good enough. + auto host_vectors = merge(keys, comp); + fusion::for_each( make_zip_view(host_vectors, keys), do_copy() ); +} + +template +void sort_by_key_sink(K &&keys, V &&vals, Comp comp) { + namespace fusion = boost::fusion; + + precondition( + fusion::at_c<0>(keys).nparts() == fusion::at_c<0>(vals).nparts(), + "Keys and values span different devices" + ); + + const auto &queue = fusion::at_c<0>(keys).queue_list(); + + for(unsigned d = 0; d < queue.size(); ++d) + if (fusion::at_c<0>(keys).part_size(d)) { + auto kpart = fusion::transform(keys, extract_device_vector(d)); + auto vpart = fusion::transform(vals, extract_device_vector(d)); + sort_by_key(queue[d], kpart, vpart, comp.device); + } + + if (queue.size() <= 1) return; + + // Vector partitions have been sorted on compute devices. + // Now we need to merge them on a CPU. This is a linear time operation, + // so total performance should be good enough. + auto host_vectors = merge(keys, vals, comp); + auto dev_vectors = fusion::join(keys, vals); + fusion::for_each( make_zip_view(host_vectors, dev_vectors), do_copy() ); } } // namespace detail @@ -1597,18 +2359,8 @@ struct greater_equal : std::greater_equal { * \param comp comparison function. */ template -void sort(vector &keys, Comp comp) { - const auto &queue = keys.queue_list(); - - for(unsigned d = 0; d < queue.size(); ++d) - if (keys.part_size(d)) detail::sort(queue[d], keys(d), comp.device); - - if (queue.size() <= 1) return; - - // Vector partitions have been sorted on compute devices. - // Now we need to merge them on a CPU. This is a linear time operation, - // so total performance should be good enough. - vex::copy(detail::merge(keys, comp), keys); +void sort(K &&keys, Comp comp) { + detail::sort_sink(detail::forward_as_sequence(keys), comp); } /// Sorts the elements in keys and values into ascending key order. @@ -1622,30 +2374,11 @@ void sort(vector &keys) { * \param comp comparison function. */ template -void sort_by_key(vector &keys, vector &vals, Comp comp) { - precondition( - keys.queue_list().size() == vals.queue_list().size(), - "Keys and values span different devices" - ); - - auto &queue = keys.queue_list(); - - for(unsigned d = 0; d < queue.size(); ++d) - if (keys.part_size(d)) - detail::sort_by_key(queue[d], keys(d), vals(d), comp.device); - - if (queue.size() <= 1) return; - - // Vector partitions have been sorted on compute devices. - // Now we need to merge them on a CPU. This is a linear time operation, - // so total performance should be good enough. - std::vector k; - std::vector v; - - detail::merge(keys, vals, comp, k, v); - - vex::copy(k, keys); - vex::copy(v, vals); +void sort_by_key(K &&keys, V &&vals, Comp comp) { + detail::sort_by_key_sink( + detail::forward_as_sequence(keys), + detail::forward_as_sequence(vals), + comp); } /// Sorts the elements in keys and values into ascending key order.