From 8071d4969d6c11bad7942206fe0f64c6ea81a688 Mon Sep 17 00:00:00 2001 From: Matthew Michel <106704043+mmichel11@users.noreply.github.com> Date: Mon, 30 Oct 2023 03:30:13 -0500 Subject: [PATCH] [SYCLomatic #1165] Add test for tagged_pointer (#427) Signed-off-by: Matthew Michel --- help_function/help_function.xml | 2 + .../src/onedpl_test_sys_tag_utils.cpp | 115 +++++++++++++ .../src/onedpl_test_tagged_pointer.cpp | 151 ++++++++++++++++++ 3 files changed, 268 insertions(+) create mode 100644 help_function/src/onedpl_test_sys_tag_utils.cpp create mode 100644 help_function/src/onedpl_test_tagged_pointer.cpp diff --git a/help_function/help_function.xml b/help_function/help_function.xml index aa40de8e2..fb3a8eeff 100644 --- a/help_function/help_function.xml +++ b/help_function/help_function.xml @@ -125,7 +125,9 @@ + + diff --git a/help_function/src/onedpl_test_sys_tag_utils.cpp b/help_function/src/onedpl_test_sys_tag_utils.cpp new file mode 100644 index 000000000..0cc19cc46 --- /dev/null +++ b/help_function/src/onedpl_test_sys_tag_utils.cpp @@ -0,0 +1,115 @@ +// ====------ onedpl_test_sys_tag_utils.cpp---------- -*- C++ -* ----===//// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// +// ===----------------------------------------------------------------------===// + +#include + +#include "dpct/dpct.hpp" +#include "dpct/dpl_utils.hpp" + +#include + +template +int ASSERT_EQUAL(String msg, _T1 &&X, _T2 &&Y) { + if (X != Y) { + std::cout << "FAIL: " << msg << " - (" << X << "," << Y << ")" << std::endl; + return 1; + } else { + std::cout << "PASS: " << msg << std::endl; + return 0; + } +} + +int test_internal_policy_conversion(void) { + int num_fails = 0; + + using seq_tag = dpct::internal::policy_or_tag_to_tag::type; + using unseq_tag = dpct::internal::policy_or_tag_to_tag_t; + using par_tag = dpct::internal::policy_or_tag_to_tag_t; + using par_unseq_tag = dpct::internal::policy_or_tag_to_tag_t; + using dev_tag = dpct::internal::policy_or_tag_to_tag_t; + using reflect_host_tag = + dpct::internal::policy_or_tag_to_tag_t; + using reflect_dev_tag = + dpct::internal::policy_or_tag_to_tag_t; + + num_fails += ASSERT_EQUAL("seq policy tag conversion", + std::is_same_v, true); + num_fails += + ASSERT_EQUAL("unseq policy tag conversion", + std::is_same_v, true); + num_fails += ASSERT_EQUAL("par policy tag conversion", + std::is_same_v, true); + num_fails += + ASSERT_EQUAL("par_unseq policy tag conversion", + std::is_same_v, true); + num_fails += + ASSERT_EQUAL("dpcpp_default policy tag conversion", + std::is_same_v, true); + num_fails += + ASSERT_EQUAL("host tag reflection", + std::is_same_v, true); + num_fails += + ASSERT_EQUAL("device tag reflection", + std::is_same_v, true); + + return num_fails; +} + +int test_internal_is_host_policy_or_tag(void) { + int num_fails = 0; + + constexpr bool seq_is_host_tag = + dpct::internal::is_host_policy_or_tag::value; + constexpr bool unseq_is_host_tag = + dpct::internal::is_host_policy_or_tag_v; + constexpr bool par_is_host_tag = + dpct::internal::is_host_policy_or_tag_v; + constexpr bool par_unseq_is_host_tag = + dpct::internal::is_host_policy_or_tag_v; + constexpr bool dev_is_host_tag = + dpct::internal::is_host_policy_or_tag_v; + constexpr bool host_tag_is_host_tag = + dpct::internal::is_host_policy_or_tag_v; + constexpr bool dev_tag_is_host_tag = + dpct::internal::is_host_policy_or_tag_v; + + num_fails += ASSERT_EQUAL("seq policy is host", seq_is_host_tag, true); + num_fails += ASSERT_EQUAL("unseq policy is host", unseq_is_host_tag, true); + num_fails += ASSERT_EQUAL("par policy is host", par_is_host_tag, true); + num_fails += + ASSERT_EQUAL("par_unseq policy is host", par_unseq_is_host_tag, true); + num_fails += + ASSERT_EQUAL("dpcpp_default policy is host", dev_is_host_tag, false); + num_fails += ASSERT_EQUAL("host tag is host", host_tag_is_host_tag, true); + num_fails += ASSERT_EQUAL("device tag is host", dev_tag_is_host_tag, false); + + return num_fails; +} + +int main() { + int failed_tests = test_internal_policy_conversion(); + failed_tests += test_internal_is_host_policy_or_tag(); + + std::cout << std::endl + << failed_tests << " failing test(s) detected." << std::endl; + if (failed_tests == 0) { + return 0; + } + return 1; +} diff --git a/help_function/src/onedpl_test_tagged_pointer.cpp b/help_function/src/onedpl_test_tagged_pointer.cpp new file mode 100644 index 000000000..3a35f97f8 --- /dev/null +++ b/help_function/src/onedpl_test_tagged_pointer.cpp @@ -0,0 +1,151 @@ +// ====------ onedpl_test_tagged_pointer.cpp---------- -*- C++ -* ----===//// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// +// ===----------------------------------------------------------------------===// + +// Adapted from onedpl_test_device_ptr.cpp + +#include +#include +#include + +#include "dpct/dpct.hpp" +#include "dpct/dpl_utils.hpp" + +#include + +// Used to test operator-> behavior +struct integer_wrapper { + int val; +}; + +template +int ASSERT_EQUAL(String msg, _T1 &&X, _T2 &&Y) { + if (X != Y) { + std::cout << "FAIL: " << msg << " - (" << X << "," << Y << ")" << std::endl; + return 1; + } else { + std::cout << "PASS: " << msg << std::endl; + return 0; + } +} + +template int test_tagged_pointer_manipulation(void) { + int failing_tests = 0; + constexpr ::std::size_t n = 20; + SystemTag system; + ::std::string sys = ::std::is_same_v + ? "dpct::host_sys_tag" + : "dpct::device_sys_tag"; + std::string int_ptr_name = "dpct::tagged_pointer<" + sys + ", int>"; + std::string int_wrapper_name = + "dpct::tagged_pointer<" + sys + ", integer_wrapper>"; + std::string void_ptr_name = "dpct::tagged_pointer<" + sys + ", void>"; + + dpct::tagged_pointer void_ptr_beg = + dpct::malloc(system, sizeof(int) * n); + dpct::tagged_pointer int_ptr_beg = + static_cast>(void_ptr_beg); + + dpct::tagged_pointer void_ptr_beg2 = + static_cast>(int_ptr_beg); + failing_tests += ASSERT_EQUAL(void_ptr_name + " conversion operator", + void_ptr_beg == void_ptr_beg2, true); + + dpct::tagged_pointer int_ptr_end = int_ptr_beg + n; + failing_tests += ASSERT_EQUAL( + int_ptr_name + " add operator", + static_cast(int_ptr_end) - static_cast(int_ptr_beg), n); + + dpct::tagged_pointer expect_beg = int_ptr_end - n; + failing_tests += ASSERT_EQUAL(int_ptr_name + " subtract operator", + int_ptr_beg == expect_beg, true); + + failing_tests += ASSERT_EQUAL(int_ptr_name + " difference operator", + int_ptr_end - int_ptr_beg, n); + + expect_beg++; + failing_tests += ASSERT_EQUAL(int_ptr_name + " postfix increment", + (int_ptr_beg + 1) == expect_beg, true); + + expect_beg--; + failing_tests += ASSERT_EQUAL(int_ptr_name + " postfix decrement", + int_ptr_beg == expect_beg, true); + + ++expect_beg; + failing_tests += ASSERT_EQUAL(int_ptr_name + " prefix increment", + (int_ptr_beg + 1) == expect_beg, true); + --expect_beg; + failing_tests += ASSERT_EQUAL(int_ptr_name + " prefix decrement", + int_ptr_beg == expect_beg, true); + + expect_beg += 2; + failing_tests += ASSERT_EQUAL(int_ptr_name + " addition assignment", + (int_ptr_beg + 2) == expect_beg, true); + + expect_beg -= 2; + failing_tests += ASSERT_EQUAL(int_ptr_name + " subtraction assignment", + int_ptr_beg == expect_beg, true); + + // Test conversion to base pointer + int *int_ptr_beg_raw = int_ptr_beg; + int *int_ptr_end_raw = int_ptr_end; + failing_tests += ASSERT_EQUAL(int_ptr_name + " conversion to int*", + int_ptr_end_raw - int_ptr_beg_raw, n); + + // device allocations use malloc_shared so this is safe + *int_ptr_beg = 4; + failing_tests += ASSERT_EQUAL(int_ptr_name + " dereference operator", + *int_ptr_beg == 4, true); + int_ptr_beg[1] = 2; + failing_tests += ASSERT_EQUAL(int_ptr_name + " subscript operator", + int_ptr_beg[1] == 2, true); + + dpct::tagged_pointer int_wrapper_beg = + dpct::malloc(system, 1); + int_wrapper_beg->val = 5; + failing_tests += ASSERT_EQUAL(int_wrapper_name + " arrow operator", + (*int_wrapper_beg).val == 5, true); + + dpct::free(system, void_ptr_beg); + dpct::free(system, int_wrapper_beg); + return failing_tests; +} + +template +int test_tagged_pointer_iteration(Policy policy, std::string test_name) { + constexpr ::std::size_t n = 1024; + int return_fail_code = 0; + + auto ptr_beg = dpct::malloc(policy, n); + auto ptr_end = ptr_beg + n; + + std::fill(policy, ptr_beg, ptr_end, 99); + int result = oneapi::dpl::reduce(policy, ptr_beg, ptr_beg + n); + return_fail_code += + ASSERT_EQUAL(test_name + " reduce algorithm test", result, n * 99); + dpct::free(policy, ptr_beg); + return return_fail_code; +} + +int main() { + int failed_tests = test_tagged_pointer_manipulation(); + failed_tests += test_tagged_pointer_manipulation(); + + failed_tests += test_tagged_pointer_iteration( + dpl::execution::seq, "dpct::tagged_pointer"); + failed_tests += test_tagged_pointer_iteration( + dpl::execution::make_device_policy(dpct::get_default_queue()), + "dpct::tagged_pointer"); + + std::cout << std::endl + << failed_tests << " failing test(s) detected." << std::endl; + if (failed_tests == 0) { + return 0; + } + return 1; +}