Skip to content

Commit

Permalink
Switch to VEX_FUNCTION(v2) internally and in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ddemidov committed Apr 2, 2014
1 parent fc12854 commit e3b83bd
Show file tree
Hide file tree
Showing 12 changed files with 48 additions and 148 deletions.
4 changes: 2 additions & 2 deletions tests/deduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ BOOST_AUTO_TEST_CASE(user_functions)
vex::vector<double> x;
vex::vector<int> y;

VEX_FUNCTION_V1(f1, double(double), "return 42;");
VEX_FUNCTION_V1(f2, int(double, double), "return 42;");
VEX_FUNCTION(double, f1, (double, x), return 42;);
VEX_FUNCTION(int, f2, (double, x)(double, y), return 42;);

check<double>( f1(x) );
check<int> ( f2(x, y) );
Expand Down
7 changes: 5 additions & 2 deletions tests/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ BOOST_AUTO_TEST_CASE(kernel_generator_with_user_function)
sym_state sym_x(sym_state::VectorParameter, sym_state::Const);
sym_state sym_y(sym_state::VectorParameter);

VEX_FUNCTION_V1(sin2, double(double), "double s = sin(prm1); return s * s;");
VEX_FUNCTION(double, sin2, (double, x),
double s = sin(x);
return s * s;
);

sym_y = sin2(sym_x);

Expand Down Expand Up @@ -100,7 +103,7 @@ BOOST_AUTO_TEST_CASE(function_generator)
static std::string function_body = vex::generator::make_function(
body.str(), sym_x, sym_x);

VEX_FUNCTION_V1(rk2, double(double), function_body);
VEX_FUNCTION_S(double, rk2, (double, prm1), function_body);

std::vector<double> x = random_vector<double>(n);
vex::vector<double> X(ctx, x);
Expand Down
2 changes: 1 addition & 1 deletion tests/multivector_arithmetics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ BOOST_AUTO_TEST_CASE(user_defined_functions)
elem_t v1 = {{1, 2}};
elem_t v2 = {{2, 1}};

VEX_FUNCTION_V1(greater, size_t(double, double), "return prm1 > prm2;");
VEX_FUNCTION(size_t, greater, (double, x)(double, y), return x > y;);

x = v1;
y = v2;
Expand Down
8 changes: 4 additions & 4 deletions tests/reduce_by_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ BOOST_AUTO_TEST_CASE(rbk_tuple)
vex::vector<cl_long> okey2;
vex::vector<double> ovals;

VEX_FUNCTION_V1(equal, bool(cl_int, cl_long, cl_int, cl_long),
"return (prm1 == prm3) && (prm2 == prm4);"
VEX_FUNCTION(bool, equal, (cl_int, a1)(cl_long, a2)(cl_int, b1)(cl_long, b2),
return (a1 == b1) && (a2 == b2);
);

VEX_FUNCTION_V1(plus, double(double, double),
"return prm1 + prm2;"
VEX_FUNCTION(double, plus, (double, x)(double, y),
return x + y;
);

int num_keys = vex::reduce_by_key(
Expand Down
38 changes: 9 additions & 29 deletions tests/sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,14 @@ BOOST_AUTO_TEST_CASE(sort_keys_vals_custom_op)

struct even_first_t {
typedef bool result_type;
even_first_t() {}

VEX_FUNCTION_V1(device, bool(int, int),
VEX_STRINGIZE_SOURCE(
char bit1 = 1 & prm1;
char bit2 = 1 & prm2;
if (bit1 == bit2) return prm1 < prm2;
return bit1 < bit2;
)
);

result_type operator()(int a, int b) const {
VEX_DUAL_FUNCTOR(result_type, (int, a)(int, b),
char bit1 = 1 & a;
char bit2 = 1 & b;
if (bit1 == bit2) return a < b;
return bit1 < bit2;
}

even_first_t() {}
)
} even_first;

std::stable_sort(p.begin(), p.end(), [&](int i, int j) {
Expand Down Expand Up @@ -108,16 +98,11 @@ BOOST_AUTO_TEST_CASE(sort_keys_tuple)

struct less_t {
typedef bool result_type;
less_t() {}

VEX_FUNCTION_V1(device, bool(int, float, int, float),
"return (prm1 == prm3) ? (prm2 < prm4) : (prm1 < prm3);"
);

result_type operator()(int a1, float a2, int b1, float b2) const {
VEX_DUAL_FUNCTOR(result_type, (int, a1)(float, a2)(int, b1)(float, b2),
return (a1 == b1) ? (a2 < b2) : (a1 < b1);
}

less_t() {}
)
} less;

vex::sort(boost::fusion::vector_tie(keys1, keys2), less );
Expand Down Expand Up @@ -151,16 +136,11 @@ BOOST_AUTO_TEST_CASE(sort_keys_vals_tuple)

struct less_t {
typedef bool result_type;
less_t() {}

VEX_FUNCTION_V1(device, bool(int, float, int, float),
"return (prm1 == prm3) ? (prm2 < prm4) : (prm1 < prm3);"
);

result_type operator()(int a1, float a2, int b1, float b2) const {
VEX_DUAL_FUNCTOR(result_type, (int, a1)(float, a2)(int, b1)(float, b2),
return (a1 == b1) ? (a2 < b2) : (a1 < b1);
}

less_t() {}
)
} less;

std::stable_sort(p.begin(), p.end(), [&](int i, int j) {
Expand Down
2 changes: 1 addition & 1 deletion tests/temporary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ BOOST_AUTO_TEST_CASE(temporary)
vex::vector<double> x(ctx, random_vector<double>(n));
vex::vector<double> y(ctx, n);

VEX_FUNCTION_V1(sqr, double(double), "return prm1 * prm1;");
VEX_FUNCTION(double, sqr, (double, x), return x * x;);

{
// Deduce temporary type
Expand Down
103 changes: 11 additions & 92 deletions tests/vector_arithmetics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ BOOST_AUTO_TEST_CASE(user_defined_functions)
x = 1;
y = 2;

VEX_FUNCTION_V1(greater, size_t(double, double), "return prm1 > prm2;");
VEX_FUNCTION(size_t, greater, (double, x)(double, y), return x > y;);

vex::Reductor<size_t,vex::SUM> sum(ctx);

Expand All @@ -110,8 +110,8 @@ BOOST_AUTO_TEST_CASE(user_defined_functions_same_signature)

x = 1;

VEX_FUNCTION_V1(times2, double(double), "return prm1 * 2;");
VEX_FUNCTION_V1(times4, double(double), "return prm1 * 4;");
VEX_FUNCTION(double, times2, (double, x), return x * 2;);
VEX_FUNCTION(double, times4, (double, x), return x * 4;);

vex::Reductor<size_t,vex::SUM> sum(ctx);

Expand All @@ -135,7 +135,7 @@ BOOST_AUTO_TEST_CASE(vector_values)
{
const size_t N = 1024;

VEX_FUNCTION_V1(make_int4, cl_int4(int), "return (int4)(prm1, prm1, prm1, prm1);");
VEX_FUNCTION(cl_int4, make_int4, (int, x), return (int4)(x, x, x, x););

cl_int4 c = {{1, 2, 3, 4}};

Expand All @@ -153,8 +153,8 @@ BOOST_AUTO_TEST_CASE(nested_functions)
{
const size_t N = 1024;

VEX_FUNCTION_V1(f, int(int), "return 2 * prm1;");
VEX_FUNCTION_V1(g, int(int), "return 3 * prm1;");
VEX_FUNCTION(int, f, (int, x), return 2 * x;);
VEX_FUNCTION(int, g, (int, x), return 3 * x;);

vex::vector<int> x(ctx, N);

Expand All @@ -175,7 +175,7 @@ BOOST_AUTO_TEST_CASE(custom_header)

vex::push_program_header(ctx, "#define THE_ANSWER 42\n");

VEX_FUNCTION_V1(answer, int(int), "return prm1 * THE_ANSWER;");
VEX_FUNCTION(int, answer, (int, x), return x * THE_ANSWER;);

x = answer(1);

Expand All @@ -193,19 +193,11 @@ BOOST_AUTO_TEST_CASE(function_with_preamble)
vex::vector<double> x(ctx, random_vector<double>(n));
vex::vector<double> y(ctx, n);

#ifdef VEXCL_BACKEND_OPENCL
VEX_FUNCTION_V1_WITH_PREAMBLE(one, double(double),
"double sin2(double x) { return pow(sin(x), 2.0); }\n"
"double cos2(double x) { return pow(cos(x), 2.0); }\n",
"return sin2(prm1) + cos2(prm1);"
);
#else
VEX_FUNCTION_V1_WITH_PREAMBLE(one, double(double),
"__device__ double sin2(double x) { return pow(sin(x), 2.0); }\n"
"__device__ double cos2(double x) { return pow(cos(x), 2.0); }\n",
"return sin2(prm1) + cos2(prm1);"
VEX_FUNCTION(double, sin2, (double, x), return pow(sin(x), 2.0););
VEX_FUNCTION(double, cos2, (double, x), return pow(cos(x), 2.0););
VEX_FUNCTION_D(double, one, (double, x), (sin2)(cos2),
return sin2(x) + cos2(x);
);
#endif

y = one(x);

Expand All @@ -214,79 +206,6 @@ BOOST_AUTO_TEST_CASE(function_with_preamble)
});
}

BOOST_AUTO_TEST_CASE(function_v2)
{
const size_t n = 1024;

vex::vector<double> x(ctx, random_vector<double>(n));
vex::vector<double> y(ctx, random_vector<double>(n));
vex::vector<double> z(ctx, n);

{
VEX_FUNCTION(double, foo, (double, x)(double, y),
return (x - y) * (x + y);
);

z = foo(x, y);

check_sample(x, y, z, [](size_t, double X, double Y, double Z) {
BOOST_CHECK_EQUAL(Z, (X - Y) * (X + Y));
});
}

{
VEX_FUNCTION_S(double, foo, (double, x)(double, y),
"return (x - y) * (x + y);"
);

z = foo(x, y);

check_sample(x, y, z, [](size_t, double X, double Y, double Z) {
BOOST_CHECK_EQUAL(Z, (X - Y) * (X + Y));
});
}

{
VEX_FUNCTION(double, bar, (double, x),
double s = sin(x);
return s * s;
);
VEX_FUNCTION(double, baz, (double, x),
double c = cos(x);
return c * c;
);
VEX_FUNCTION_D(double, foo, (double, x)(double, y), (bar)(baz),
return bar(x - y) * baz(x + y);
);

z = foo(x, y);

check_sample(x, y, z, [](size_t, double X, double Y, double Z) {
BOOST_CHECK_CLOSE(Z, pow(sin(X - Y), 2) * pow(cos(X + Y), 2), 1e-8);
});
}

{
VEX_FUNCTION(double, bar, (double, x),
double s = sin(x);
return s * s;
);
VEX_FUNCTION(double, baz, (double, x),
double c = cos(x);
return c * c;
);
VEX_FUNCTION_DS(double, foo, (double, x)(double, y), (bar)(baz),
"return bar(x - y) * baz(x + y);"
);

z = foo(x, y);

check_sample(x, y, z, [](size_t, double X, double Y, double Z) {
BOOST_CHECK_CLOSE(Z, pow(sin(X - Y), 2) * pow(cos(X + Y), 2), 1e-8);
});
}
}

BOOST_AUTO_TEST_CASE(ternary_operator)
{
const size_t n = 1024;
Expand Down
12 changes: 5 additions & 7 deletions tests/vector_pointer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@ BOOST_AUTO_TEST_CASE(nbody)
vex::vector<double> x(queue, X);
vex::vector<double> y(queue, n);

VEX_FUNCTION_V1(nbody, double(size_t, size_t, double*),
VEX_STRINGIZE_SOURCE(
double sum = 0;
for(size_t i = 0; i < prm1; ++i)
if (i != prm2) sum += prm3[i];
return sum;
)
VEX_FUNCTION(double, nbody, (size_t, n)(size_t, j)(double*, x),
double sum = 0;
for(size_t i = 0; i < n; ++i)
if (i != j) sum += x[i];
return sum;
);

y = nbody(n, vex::element_index(), vex::raw_pointer(x));
Expand Down
4 changes: 2 additions & 2 deletions vexcl/backend/opencl/fft/plan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ struct plan {

typedef typename cl_vector_of<Ts, 2>::type T2;

VEX_FUNCTION_V1(r2c, T2(Ts), "return (" + type_name<T2>() + ")(prm1, 0);");
VEX_FUNCTION_V1(c2r, Ts(T2), "return prm1.x;");
VEX_FUNCTION_S(T2, r2c, (Ts, v), "return (" + type_name<T2>() + ")(v, 0);");
VEX_FUNCTION_S(Ts, c2r, (T2, v), "return v.x;");

const std::vector<backend::command_queue> &queues;
Planner planner;
Expand Down
6 changes: 3 additions & 3 deletions vexcl/reduce_by_key.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ int reduce_by_key_sink(

krn0(queue[0]);

VEX_FUNCTION_V1(plus, int(int, int), "return prm1 + prm2;");
VEX_FUNCTION(int, plus, (int, x)(int, y), return x + y;);
detail::scan(queue[0], offset, offset, 0, false, plus);

/***** Kernel 1 *****/
Expand Down Expand Up @@ -580,8 +580,8 @@ int reduce_by_key(
vector<V> &ovals
)
{
VEX_FUNCTION_V1(equal, bool(K, K), "return prm1 == prm2;");
VEX_FUNCTION_V1(plus, V(V, V), "return prm1 + prm2;");
VEX_FUNCTION(bool, equal, (K, x)(K, y), return x == y;);
VEX_FUNCTION(V, plus, (V, x)(V, y), return x + y;);
return reduce_by_key(ikeys, ivals, okeys, ovals, equal, plus);
}

Expand Down
2 changes: 1 addition & 1 deletion vexcl/scan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ void scan(
/// Binary function object class whose call returns the result of adding its two arguments.
template <typename T>
struct plus : std::plus<T> {
VEX_FUNCTION_V1(device, T(T, T), "return prm1 + prm2;");
VEX_FUNCTION(T, device, (T, x)(T, y), return x + y;);

plus() {}
};
Expand Down
8 changes: 4 additions & 4 deletions vexcl/sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2212,7 +2212,7 @@ void sort_by_key_sink(K &&keys, V &&vals, Comp comp) {
*/
template <typename T>
struct less : std::less<T> {
VEX_FUNCTION_V1(device, bool(T, T), "return prm1 < prm2;");
VEX_FUNCTION(bool, device, (T, x)(T, y), return x < y;);

less() {}
};
Expand All @@ -2225,7 +2225,7 @@ struct less : std::less<T> {
*/
template <typename T>
struct less_equal : std::less_equal<T> {
VEX_FUNCTION_V1(device, bool(T, T), "return prm1 <= prm2;");
VEX_FUNCTION(bool, device, (T, x)(T, y), return x <= y;);

less_equal() {}
};
Expand All @@ -2238,7 +2238,7 @@ struct less_equal : std::less_equal<T> {
*/
template <typename T>
struct greater : std::greater<T> {
VEX_FUNCTION_V1(device, bool(T, T), "return prm1 > prm2;");
VEX_FUNCTION(bool, device, (T, x)(T, y), return x > y;);

greater() {}
};
Expand All @@ -2251,7 +2251,7 @@ struct greater : std::greater<T> {
*/
template <typename T>
struct greater_equal : std::greater_equal<T> {
VEX_FUNCTION_V1(device, bool(T, T), "return prm1 >= prm2;");
VEX_FUNCTION(bool, device, (T, x)(T, y), return x >= y;);

greater_equal() {}
};
Expand Down

0 comments on commit e3b83bd

Please sign in to comment.