From e3b83bde5a1f790190cb4778793f43c94c0d4c7f Mon Sep 17 00:00:00 2001 From: Denis Demidov Date: Wed, 2 Apr 2014 12:27:41 +0400 Subject: [PATCH] Switch to VEX_FUNCTION(v2) internally and in tests --- tests/deduce.cpp | 4 +- tests/generator.cpp | 7 +- tests/multivector_arithmetics.cpp | 2 +- tests/reduce_by_key.cpp | 8 +-- tests/sort.cpp | 38 +++-------- tests/temporary.cpp | 2 +- tests/vector_arithmetics.cpp | 103 ++++-------------------------- tests/vector_pointer.cpp | 12 ++-- vexcl/backend/opencl/fft/plan.hpp | 4 +- vexcl/reduce_by_key.hpp | 6 +- vexcl/scan.hpp | 2 +- vexcl/sort.hpp | 8 +-- 12 files changed, 48 insertions(+), 148 deletions(-) diff --git a/tests/deduce.cpp b/tests/deduce.cpp index 4ecbc8164..86388765b 100644 --- a/tests/deduce.cpp +++ b/tests/deduce.cpp @@ -109,8 +109,8 @@ BOOST_AUTO_TEST_CASE(user_functions) vex::vector x; vex::vector y; - VEX_FUNCTION_V1(f1, double(double), "return 42;"); - VEX_FUNCTION_V1(f2, int(double, double), "return 42;"); + VEX_FUNCTION(double, f1, (double, x), return 42;); + VEX_FUNCTION(int, f2, (double, x)(double, y), return 42;); check( f1(x) ); check ( f2(x, y) ); diff --git a/tests/generator.cpp b/tests/generator.cpp index fbf18f416..1f1e2eb07 100644 --- a/tests/generator.cpp +++ b/tests/generator.cpp @@ -64,7 +64,10 @@ BOOST_AUTO_TEST_CASE(kernel_generator_with_user_function) sym_state sym_x(sym_state::VectorParameter, sym_state::Const); sym_state sym_y(sym_state::VectorParameter); - VEX_FUNCTION_V1(sin2, double(double), "double s = sin(prm1); return s * s;"); + VEX_FUNCTION(double, sin2, (double, x), + double s = sin(x); + return s * s; + ); sym_y = sin2(sym_x); @@ -100,7 +103,7 @@ BOOST_AUTO_TEST_CASE(function_generator) static std::string function_body = vex::generator::make_function( body.str(), sym_x, sym_x); - VEX_FUNCTION_V1(rk2, double(double), function_body); + VEX_FUNCTION_S(double, rk2, (double, prm1), function_body); std::vector x = random_vector(n); vex::vector X(ctx, x); diff --git a/tests/multivector_arithmetics.cpp b/tests/multivector_arithmetics.cpp index 8e92ebacd..6aea29c42 100644 --- a/tests/multivector_arithmetics.cpp +++ b/tests/multivector_arithmetics.cpp @@ -104,7 +104,7 @@ BOOST_AUTO_TEST_CASE(user_defined_functions) elem_t v1 = {{1, 2}}; elem_t v2 = {{2, 1}}; - VEX_FUNCTION_V1(greater, size_t(double, double), "return prm1 > prm2;"); + VEX_FUNCTION(size_t, greater, (double, x)(double, y), return x > y;); x = v1; y = v2; diff --git a/tests/reduce_by_key.cpp b/tests/reduce_by_key.cpp index 7402f93b0..7288d278a 100644 --- a/tests/reduce_by_key.cpp +++ b/tests/reduce_by_key.cpp @@ -94,12 +94,12 @@ BOOST_AUTO_TEST_CASE(rbk_tuple) vex::vector okey2; vex::vector ovals; - VEX_FUNCTION_V1(equal, bool(cl_int, cl_long, cl_int, cl_long), - "return (prm1 == prm3) && (prm2 == prm4);" + VEX_FUNCTION(bool, equal, (cl_int, a1)(cl_long, a2)(cl_int, b1)(cl_long, b2), + return (a1 == b1) && (a2 == b2); ); - VEX_FUNCTION_V1(plus, double(double, double), - "return prm1 + prm2;" + VEX_FUNCTION(double, plus, (double, x)(double, y), + return x + y; ); int num_keys = vex::reduce_by_key( diff --git a/tests/sort.cpp b/tests/sort.cpp index db4951268..bfacb620e 100644 --- a/tests/sort.cpp +++ b/tests/sort.cpp @@ -59,24 +59,14 @@ BOOST_AUTO_TEST_CASE(sort_keys_vals_custom_op) struct even_first_t { typedef bool result_type; + even_first_t() {} - VEX_FUNCTION_V1(device, bool(int, int), - VEX_STRINGIZE_SOURCE( - char bit1 = 1 & prm1; - char bit2 = 1 & prm2; - if (bit1 == bit2) return prm1 < prm2; - return bit1 < bit2; - ) - ); - - result_type operator()(int a, int b) const { + VEX_DUAL_FUNCTOR(result_type, (int, a)(int, b), char bit1 = 1 & a; char bit2 = 1 & b; if (bit1 == bit2) return a < b; return bit1 < bit2; - } - - even_first_t() {} + ) } even_first; std::stable_sort(p.begin(), p.end(), [&](int i, int j) { @@ -108,16 +98,11 @@ BOOST_AUTO_TEST_CASE(sort_keys_tuple) struct less_t { typedef bool result_type; + less_t() {} - VEX_FUNCTION_V1(device, bool(int, float, int, float), - "return (prm1 == prm3) ? (prm2 < prm4) : (prm1 < prm3);" - ); - - result_type operator()(int a1, float a2, int b1, float b2) const { + VEX_DUAL_FUNCTOR(result_type, (int, a1)(float, a2)(int, b1)(float, b2), return (a1 == b1) ? (a2 < b2) : (a1 < b1); - } - - less_t() {} + ) } less; vex::sort(boost::fusion::vector_tie(keys1, keys2), less ); @@ -151,16 +136,11 @@ BOOST_AUTO_TEST_CASE(sort_keys_vals_tuple) struct less_t { typedef bool result_type; + less_t() {} - VEX_FUNCTION_V1(device, bool(int, float, int, float), - "return (prm1 == prm3) ? (prm2 < prm4) : (prm1 < prm3);" - ); - - result_type operator()(int a1, float a2, int b1, float b2) const { + VEX_DUAL_FUNCTOR(result_type, (int, a1)(float, a2)(int, b1)(float, b2), return (a1 == b1) ? (a2 < b2) : (a1 < b1); - } - - less_t() {} + ) } less; std::stable_sort(p.begin(), p.end(), [&](int i, int j) { diff --git a/tests/temporary.cpp b/tests/temporary.cpp index aff190996..1037f199f 100644 --- a/tests/temporary.cpp +++ b/tests/temporary.cpp @@ -14,7 +14,7 @@ BOOST_AUTO_TEST_CASE(temporary) vex::vector x(ctx, random_vector(n)); vex::vector y(ctx, n); - VEX_FUNCTION_V1(sqr, double(double), "return prm1 * prm1;"); + VEX_FUNCTION(double, sqr, (double, x), return x * x;); { // Deduce temporary type diff --git a/tests/vector_arithmetics.cpp b/tests/vector_arithmetics.cpp index 12667816d..2b6096b22 100644 --- a/tests/vector_arithmetics.cpp +++ b/tests/vector_arithmetics.cpp @@ -95,7 +95,7 @@ BOOST_AUTO_TEST_CASE(user_defined_functions) x = 1; y = 2; - VEX_FUNCTION_V1(greater, size_t(double, double), "return prm1 > prm2;"); + VEX_FUNCTION(size_t, greater, (double, x)(double, y), return x > y;); vex::Reductor sum(ctx); @@ -110,8 +110,8 @@ BOOST_AUTO_TEST_CASE(user_defined_functions_same_signature) x = 1; - VEX_FUNCTION_V1(times2, double(double), "return prm1 * 2;"); - VEX_FUNCTION_V1(times4, double(double), "return prm1 * 4;"); + VEX_FUNCTION(double, times2, (double, x), return x * 2;); + VEX_FUNCTION(double, times4, (double, x), return x * 4;); vex::Reductor sum(ctx); @@ -135,7 +135,7 @@ BOOST_AUTO_TEST_CASE(vector_values) { const size_t N = 1024; - VEX_FUNCTION_V1(make_int4, cl_int4(int), "return (int4)(prm1, prm1, prm1, prm1);"); + VEX_FUNCTION(cl_int4, make_int4, (int, x), return (int4)(x, x, x, x);); cl_int4 c = {{1, 2, 3, 4}}; @@ -153,8 +153,8 @@ BOOST_AUTO_TEST_CASE(nested_functions) { const size_t N = 1024; - VEX_FUNCTION_V1(f, int(int), "return 2 * prm1;"); - VEX_FUNCTION_V1(g, int(int), "return 3 * prm1;"); + VEX_FUNCTION(int, f, (int, x), return 2 * x;); + VEX_FUNCTION(int, g, (int, x), return 3 * x;); vex::vector x(ctx, N); @@ -175,7 +175,7 @@ BOOST_AUTO_TEST_CASE(custom_header) vex::push_program_header(ctx, "#define THE_ANSWER 42\n"); - VEX_FUNCTION_V1(answer, int(int), "return prm1 * THE_ANSWER;"); + VEX_FUNCTION(int, answer, (int, x), return x * THE_ANSWER;); x = answer(1); @@ -193,19 +193,11 @@ BOOST_AUTO_TEST_CASE(function_with_preamble) vex::vector x(ctx, random_vector(n)); vex::vector y(ctx, n); -#ifdef VEXCL_BACKEND_OPENCL - VEX_FUNCTION_V1_WITH_PREAMBLE(one, double(double), - "double sin2(double x) { return pow(sin(x), 2.0); }\n" - "double cos2(double x) { return pow(cos(x), 2.0); }\n", - "return sin2(prm1) + cos2(prm1);" - ); -#else - VEX_FUNCTION_V1_WITH_PREAMBLE(one, double(double), - "__device__ double sin2(double x) { return pow(sin(x), 2.0); }\n" - "__device__ double cos2(double x) { return pow(cos(x), 2.0); }\n", - "return sin2(prm1) + cos2(prm1);" + VEX_FUNCTION(double, sin2, (double, x), return pow(sin(x), 2.0);); + VEX_FUNCTION(double, cos2, (double, x), return pow(cos(x), 2.0);); + VEX_FUNCTION_D(double, one, (double, x), (sin2)(cos2), + return sin2(x) + cos2(x); ); -#endif y = one(x); @@ -214,79 +206,6 @@ BOOST_AUTO_TEST_CASE(function_with_preamble) }); } -BOOST_AUTO_TEST_CASE(function_v2) -{ - const size_t n = 1024; - - vex::vector x(ctx, random_vector(n)); - vex::vector y(ctx, random_vector(n)); - vex::vector z(ctx, n); - - { - VEX_FUNCTION(double, foo, (double, x)(double, y), - return (x - y) * (x + y); - ); - - z = foo(x, y); - - check_sample(x, y, z, [](size_t, double X, double Y, double Z) { - BOOST_CHECK_EQUAL(Z, (X - Y) * (X + Y)); - }); - } - - { - VEX_FUNCTION_S(double, foo, (double, x)(double, y), - "return (x - y) * (x + y);" - ); - - z = foo(x, y); - - check_sample(x, y, z, [](size_t, double X, double Y, double Z) { - BOOST_CHECK_EQUAL(Z, (X - Y) * (X + Y)); - }); - } - - { - VEX_FUNCTION(double, bar, (double, x), - double s = sin(x); - return s * s; - ); - VEX_FUNCTION(double, baz, (double, x), - double c = cos(x); - return c * c; - ); - VEX_FUNCTION_D(double, foo, (double, x)(double, y), (bar)(baz), - return bar(x - y) * baz(x + y); - ); - - z = foo(x, y); - - check_sample(x, y, z, [](size_t, double X, double Y, double Z) { - BOOST_CHECK_CLOSE(Z, pow(sin(X - Y), 2) * pow(cos(X + Y), 2), 1e-8); - }); - } - - { - VEX_FUNCTION(double, bar, (double, x), - double s = sin(x); - return s * s; - ); - VEX_FUNCTION(double, baz, (double, x), - double c = cos(x); - return c * c; - ); - VEX_FUNCTION_DS(double, foo, (double, x)(double, y), (bar)(baz), - "return bar(x - y) * baz(x + y);" - ); - - z = foo(x, y); - - check_sample(x, y, z, [](size_t, double X, double Y, double Z) { - BOOST_CHECK_CLOSE(Z, pow(sin(X - Y), 2) * pow(cos(X + Y), 2), 1e-8); - }); - } -} - BOOST_AUTO_TEST_CASE(ternary_operator) { const size_t n = 1024; diff --git a/tests/vector_pointer.cpp b/tests/vector_pointer.cpp index ee1451764..14abbff46 100644 --- a/tests/vector_pointer.cpp +++ b/tests/vector_pointer.cpp @@ -20,13 +20,11 @@ BOOST_AUTO_TEST_CASE(nbody) vex::vector x(queue, X); vex::vector y(queue, n); - VEX_FUNCTION_V1(nbody, double(size_t, size_t, double*), - VEX_STRINGIZE_SOURCE( - double sum = 0; - for(size_t i = 0; i < prm1; ++i) - if (i != prm2) sum += prm3[i]; - return sum; - ) + VEX_FUNCTION(double, nbody, (size_t, n)(size_t, j)(double*, x), + double sum = 0; + for(size_t i = 0; i < n; ++i) + if (i != j) sum += x[i]; + return sum; ); y = nbody(n, vex::element_index(), vex::raw_pointer(x)); diff --git a/vexcl/backend/opencl/fft/plan.hpp b/vexcl/backend/opencl/fft/plan.hpp index 66fd90989..4f3bba58e 100644 --- a/vexcl/backend/opencl/fft/plan.hpp +++ b/vexcl/backend/opencl/fft/plan.hpp @@ -188,8 +188,8 @@ struct plan { typedef typename cl_vector_of::type T2; - VEX_FUNCTION_V1(r2c, T2(Ts), "return (" + type_name() + ")(prm1, 0);"); - VEX_FUNCTION_V1(c2r, Ts(T2), "return prm1.x;"); + VEX_FUNCTION_S(T2, r2c, (Ts, v), "return (" + type_name() + ")(v, 0);"); + VEX_FUNCTION_S(Ts, c2r, (T2, v), "return v.x;"); const std::vector &queues; Planner planner; diff --git a/vexcl/reduce_by_key.hpp b/vexcl/reduce_by_key.hpp index 4f51a9536..e868e25c3 100644 --- a/vexcl/reduce_by_key.hpp +++ b/vexcl/reduce_by_key.hpp @@ -486,7 +486,7 @@ int reduce_by_key_sink( krn0(queue[0]); - VEX_FUNCTION_V1(plus, int(int, int), "return prm1 + prm2;"); + VEX_FUNCTION(int, plus, (int, x)(int, y), return x + y;); detail::scan(queue[0], offset, offset, 0, false, plus); /***** Kernel 1 *****/ @@ -580,8 +580,8 @@ int reduce_by_key( vector &ovals ) { - VEX_FUNCTION_V1(equal, bool(K, K), "return prm1 == prm2;"); - VEX_FUNCTION_V1(plus, V(V, V), "return prm1 + prm2;"); + VEX_FUNCTION(bool, equal, (K, x)(K, y), return x == y;); + VEX_FUNCTION(V, plus, (V, x)(V, y), return x + y;); return reduce_by_key(ikeys, ivals, okeys, ovals, equal, plus); } diff --git a/vexcl/scan.hpp b/vexcl/scan.hpp index 126d0fead..b09defdd9 100644 --- a/vexcl/scan.hpp +++ b/vexcl/scan.hpp @@ -424,7 +424,7 @@ void scan( /// Binary function object class whose call returns the result of adding its two arguments. template struct plus : std::plus { - VEX_FUNCTION_V1(device, T(T, T), "return prm1 + prm2;"); + VEX_FUNCTION(T, device, (T, x)(T, y), return x + y;); plus() {} }; diff --git a/vexcl/sort.hpp b/vexcl/sort.hpp index c659f66a9..02a47dbe1 100644 --- a/vexcl/sort.hpp +++ b/vexcl/sort.hpp @@ -2212,7 +2212,7 @@ void sort_by_key_sink(K &&keys, V &&vals, Comp comp) { */ template struct less : std::less { - VEX_FUNCTION_V1(device, bool(T, T), "return prm1 < prm2;"); + VEX_FUNCTION(bool, device, (T, x)(T, y), return x < y;); less() {} }; @@ -2225,7 +2225,7 @@ struct less : std::less { */ template struct less_equal : std::less_equal { - VEX_FUNCTION_V1(device, bool(T, T), "return prm1 <= prm2;"); + VEX_FUNCTION(bool, device, (T, x)(T, y), return x <= y;); less_equal() {} }; @@ -2238,7 +2238,7 @@ struct less_equal : std::less_equal { */ template struct greater : std::greater { - VEX_FUNCTION_V1(device, bool(T, T), "return prm1 > prm2;"); + VEX_FUNCTION(bool, device, (T, x)(T, y), return x > y;); greater() {} }; @@ -2251,7 +2251,7 @@ struct greater : std::greater { */ template struct greater_equal : std::greater_equal { - VEX_FUNCTION_V1(device, bool(T, T), "return prm1 >= prm2;"); + VEX_FUNCTION(bool, device, (T, x)(T, y), return x >= y;); greater_equal() {} };