From a9378fbbe9f62c65a174df6fc3c81b8224a33755 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Scipione?= Date: Fri, 23 Aug 2024 10:53:24 +0200 Subject: [PATCH] Update rotmg interface to handle issue in OpenCL CPU support (#532) * Update rotmg implementation * Add conditional check on mem_type in use * Add dependencies to copy_y1 operation * Add new interface to library and add related tests. --- src/interface/blas1/rotmg.cpp.in | 11 + src/interface/blas1_interface.hpp | 33 ++- test/unittest/blas1/blas1_rotmg_test.cpp | 272 +++++++++++++---------- 3 files changed, 192 insertions(+), 124 deletions(-) diff --git a/src/interface/blas1/rotmg.cpp.in b/src/interface/blas1/rotmg.cpp.in index 75d346b4b..557530738 100644 --- a/src/interface/blas1/rotmg.cpp.in +++ b/src/interface/blas1/rotmg.cpp.in @@ -72,11 +72,22 @@ template typename SB_Handle::event_t _rotmg( BufferIterator<${DATA_TYPE}> _y1, BufferIterator<${DATA_TYPE}> _param, const typename SB_Handle::event_t& dependencies); +template typename SB_Handle::event_t _rotmg( + SB_Handle& sb_handle, BufferIterator<${DATA_TYPE}> _d1, + BufferIterator<${DATA_TYPE}> _d2, BufferIterator<${DATA_TYPE}> _x1, + ${DATA_TYPE} _y1, BufferIterator<${DATA_TYPE}> _param, + const typename SB_Handle::event_t& dependencies); + #ifdef SB_ENABLE_USM template typename SB_Handle::event_t _rotmg( SB_Handle& sb_handle, ${DATA_TYPE} * _d1, ${DATA_TYPE} * _d2, ${DATA_TYPE} * _x1, ${DATA_TYPE} * _y1, ${DATA_TYPE} * _param, const typename SB_Handle::event_t& dependencies); + +template typename SB_Handle::event_t _rotmg( + SB_Handle& sb_handle, ${DATA_TYPE} * _d1, ${DATA_TYPE} * _d2, + ${DATA_TYPE} * _x1, ${DATA_TYPE} _y1, ${DATA_TYPE} * _param, + const typename SB_Handle::event_t& dependencies); #endif } // namespace internal diff --git a/src/interface/blas1_interface.hpp b/src/interface/blas1_interface.hpp index b9ee55aab..02f35e06e 100644 --- a/src/interface/blas1_interface.hpp +++ b/src/interface/blas1_interface.hpp @@ -811,14 +811,37 @@ typename sb_handle_t::event_t _rotmg( auto d1_view = make_vector_view(_d1, inc, vector_size); auto d2_view = make_vector_view(_d2, inc, vector_size); auto x1_view = make_vector_view(_x1, inc, vector_size); - auto y1_view = make_vector_view(_y1, inc, vector_size); auto param_view = make_vector_view(_param, inc, param_size); - auto operation = - Rotmg(d1_view, d2_view, x1_view, y1_view, param_view); - auto ret = sb_handle.execute(operation, _dependencies); + if constexpr (std::is_arithmetic_v) { + constexpr helper::AllocType mem_type = std::is_pointer_v + ? helper::AllocType::usm + : helper::AllocType::buffer; + auto _y1_tmp = blas::helper::allocate( + 1, sb_handle.get_queue()); - return ret; + auto copy_y1 = blas::helper::copy_to_device(sb_handle.get_queue(), &_y1, + _y1_tmp, 1, _dependencies); + + auto y1_view = make_vector_view(_y1_tmp, inc, vector_size); + auto operation = Rotmg(d1_view, d2_view, x1_view, + y1_view, param_view); + + auto operator_event = + sb_handle.execute(operation, typename sb_handle_t::event_t{copy_y1}); + if constexpr (mem_type != helper::AllocType::buffer) { + // This wait is necessary to free the temporary memory created above and + // avoiding the host_task + operator_event[0].wait(); + sycl::free(_y1_tmp, sb_handle.get_queue()); + } + return operator_event; + } else { + auto y1_view = make_vector_view(_y1, inc, vector_size); + auto operation = Rotmg(d1_view, d2_view, x1_view, + y1_view, param_view); + return sb_handle.execute(operation, _dependencies); + } } /** diff --git a/test/unittest/blas1/blas1_rotmg_test.cpp b/test/unittest/blas1/blas1_rotmg_test.cpp index b037ba76d..604e39fe9 100644 --- a/test/unittest/blas1/blas1_rotmg_test.cpp +++ b/test/unittest/blas1/blas1_rotmg_test.cpp @@ -25,7 +25,8 @@ #include "blas_test.hpp" -template +template struct RotmgTest { /* Magic numbers used by the rotmg algorithm */ static constexpr scalar_t gamma = static_cast(4096.0); @@ -57,8 +58,8 @@ struct RotmgTest { void validate_with_rotm(); }; -template -void RotmgTest::run_portblas_rotmg() { +template +void RotmgTest::run_portblas_rotmg() { auto q = make_queue(); blas::SB_Handle sb_handle(q); @@ -67,38 +68,48 @@ void RotmgTest::run_portblas_rotmg() { auto device_d1 = helper::allocate(1, q); auto device_d2 = helper::allocate(1, q); auto device_x1 = helper::allocate(1, q); - auto device_y1 = helper::allocate(1, q); + decltype(device_x1) device_y1; auto device_param = helper::allocate(param_size, q); auto copy_d1 = helper::copy_to_device(q, &sycl_out.d1, device_d1, 1); auto copy_d2 = helper::copy_to_device(q, &sycl_out.d2, device_d2, 1); auto copy_x1 = helper::copy_to_device(q, &sycl_out.x1, device_x1, 1); - auto copy_y1 = helper::copy_to_device(q, &sycl_out.y1, device_y1, 1); auto copy_params = helper::copy_to_device(q, sycl_out.param.data(), device_param, param_size); - auto rotmg_event = - _rotmg(sb_handle, device_d1, device_d2, device_x1, device_y1, - device_param, {copy_d1, copy_d2, copy_x1, copy_y1, copy_params}); - sb_handle.wait(rotmg_event); + if constexpr (is_pointer) { + device_y1 = helper::allocate(1, q); + auto copy_y1 = helper::copy_to_device(q, &sycl_out.y1, device_y1, 1); + + auto rotmg_event = + _rotmg(sb_handle, device_d1, device_d2, device_x1, device_y1, + device_param, {copy_d1, copy_d2, copy_x1, copy_y1, copy_params}); + sb_handle.wait(rotmg_event); + } else { + auto rotmg_event = + _rotmg(sb_handle, device_d1, device_d2, device_x1, sycl_out.y1, + device_param, {copy_d1, copy_d2, copy_x1, copy_params}); + sb_handle.wait(rotmg_event); + } auto event1 = helper::copy_to_host(q, device_d1, &sycl_out.d1, 1); auto event2 = helper::copy_to_host(q, device_d2, &sycl_out.d2, 1); auto event3 = helper::copy_to_host(q, device_x1, &sycl_out.x1, 1); - auto event4 = helper::copy_to_host(q, device_y1, &sycl_out.y1, 1); - auto event5 = + auto event4 = helper::copy_to_host(q, device_param, sycl_out.param.data(), param_size); - sb_handle.wait({event1, event2, event3, event4, event5}); + sb_handle.wait({event1, event2, event3, event4}); helper::deallocate(device_d1, q); helper::deallocate(device_d2, q); helper::deallocate(device_x1, q); - helper::deallocate(device_y1, q); helper::deallocate(device_param, q); + if constexpr (is_pointer) { + helper::deallocate(device_y1, q); + } } -template -void RotmgTest::validate_with_reference() { +template +void RotmgTest::validate_with_reference() { scalar_t d1_ref = input.d1; scalar_t d2_ref = input.d2; scalar_t x1_ref = input.x1; @@ -129,8 +140,7 @@ void RotmgTest::validate_with_reference() { const bool isAlmostEqual = utils::almost_equal(sycl_out.d1, d1_ref) && utils::almost_equal(sycl_out.d2, d2_ref) && - utils::almost_equal(sycl_out.x1, x1_ref) && - utils::almost_equal(sycl_out.y1, y1_ref); + utils::almost_equal(sycl_out.x1, x1_ref); ASSERT_TRUE(isAlmostEqual); /* Validate param */ @@ -175,8 +185,8 @@ void RotmgTest::validate_with_reference() { * x1_output * sqrt(d1_output) = [ h11 h12 ] * [ x1_input] * 0.0 * sqrt(d2_output) [ h21 h22 ] [ y1_input] */ -template -void RotmgTest::validate_with_rotm() { +template +void RotmgTest::validate_with_rotm() { if (sycl_out.param[0] == 2 || sycl_out.d2 < 0) { return; } @@ -200,9 +210,9 @@ void RotmgTest::validate_with_rotm() { template using combination_t = - std::tuple; + std::tuple; -template +template void run_test(const combination_t combi) { std::string alloc; scalar_t d1_input; @@ -210,11 +220,13 @@ void run_test(const combination_t combi) { scalar_t x1_input; scalar_t y1_input; bool will_overflow; + bool is_pointer_unused; - std::tie(alloc, d1_input, d2_input, x1_input, y1_input, will_overflow) = - combi; + std::tie(alloc, d1_input, d2_input, x1_input, y1_input, will_overflow, + is_pointer_unused) = combi; - RotmgTest test{d1_input, d2_input, x1_input, y1_input}; + RotmgTest test{d1_input, d2_input, x1_input, + y1_input}; test.run_portblas_rotmg(); /* Do not test with things that might overflow or underflow. Results will @@ -233,18 +245,25 @@ void run_test(const combination_t combi) { scalar_t x1_input; scalar_t y1_input; bool will_overflow; + bool is_pointer; - std::tie(alloc, d1_input, d2_input, x1_input, y1_input, will_overflow) = - combi; + std::tie(alloc, d1_input, d2_input, x1_input, y1_input, will_overflow, + is_pointer) = combi; if (alloc == "usm") { // usm alloc #ifdef SB_ENABLE_USM - run_test(combi); + if (is_pointer) + run_test(combi); + else + run_test(combi); #else GTEST_SKIP(); #endif } else { // buffer alloc - run_test(combi); + if (is_pointer) + run_test(combi); + else + run_test(combi); } } @@ -290,155 +309,170 @@ scalar_t scale_up_gen() { } /* This tests try to cover every code path of the rotmg algorithm */ -#define INSTANTIATE_ROTMG_TESTS(NAME, C) \ +#define INSTANTIATE_ROTMG_TESTS(NAME, C, IS_POINTER) \ template \ const auto NAME = ::testing:: \ Values(/* d1 < 0 */ \ std::make_tuple(C, -2.5, p_gen(), r_gen(), \ - r_gen(), \ - false), /* Input point (c, 0) */ \ + r_gen(), false, \ + IS_POINTER), /* Input point (c, 0) */ \ std::make_tuple(C, p_gen(), p_gen(), \ - r_gen(), 0.0, \ - false), /* Input point (c, 0) && d2 == 0 */ \ - std::make_tuple(C, p_gen(), 0.0, r_gen(), \ - 0.0, false), /* Input point (c, 0) && d2 == 0 */ \ + r_gen(), 0.0, false, \ + IS_POINTER), /* Input point (c, 0) && d2 == 0 */ \ std::make_tuple(C, p_gen(), 0.0, r_gen(), \ - r_gen(), \ - false), /* Input point (c, 0) and big numbers \ - (test that no rescaling happened) */ \ - std::make_tuple(C, scale_up_gen(), \ - scale_up_gen(), \ - scale_up_gen(), 0.0, false), \ + r_gen(), false, \ + IS_POINTER), /* Input point (c, 0) and big \ + numbers (test that no rescaling happened) */ \ + std::make_tuple( \ + C, scale_up_gen(), scale_up_gen(), \ + scale_up_gen(), 0.0, false, IS_POINTER), \ std::make_tuple(C, scale_down_gen(), \ scale_down_gen(), \ - scale_down_gen(), 0.0, \ - false), /* Input point (0, c) */ \ + scale_down_gen(), 0.0, false, \ + IS_POINTER), /* Input point (0, c) */ \ std::make_tuple(C, p_gen(), p_gen(), 0.0, \ - r_gen(), \ - false), /* Input point (0, c) && d1 == 0 */ \ + r_gen(), false, \ + IS_POINTER), /* Input point (0, c) && d1 == 0 */ \ std::make_tuple(C, 0.0, p_gen(), 0.0, \ - r_gen(), \ - false), /* Input point (0, c) && d2 == 0 */ \ + r_gen(), false, \ + IS_POINTER), /* Input point (0, c) && d2 == 0 */ \ std::make_tuple(C, p_gen(), 0.0, 0.0, \ - r_gen(), \ - false), /* Input point (0, c) && d2 < 0 */ \ - std::make_tuple(C, p_gen(), -3.4, 0.0, \ - r_gen(), \ - false), /* Input point (0, c) && rescaling */ \ + r_gen(), false, \ + IS_POINTER), /* Input point (0, c) && d2 < 0 */ \ + std::make_tuple( \ + C, p_gen(), -3.4, 0.0, r_gen(), false, \ + IS_POINTER), /* Input point (0, c) && rescaling */ \ std::make_tuple(C, p_gen(), scale_up_gen(), \ - 0.0, r_gen(), false), \ + 0.0, r_gen(), false, IS_POINTER), \ std::make_tuple(C, p_gen(), scale_down_gen(), \ - 0.0, r_gen(), false), \ + 0.0, r_gen(), false, IS_POINTER), \ std::make_tuple(C, scale_up_gen(), p_gen(), \ - 0.0, r_gen(), false), \ + 0.0, r_gen(), false, IS_POINTER), \ std::make_tuple(C, scale_down_gen(), p_gen(), \ - 0.0, r_gen(), false), /* d1 == 0 */ \ + 0.0, r_gen(), false, \ + IS_POINTER), /* d1 == 0 */ \ std::make_tuple(C, 0.0, p_gen(), r_gen(), \ - r_gen(), \ - false), /* d1 == 0 && d2 < 0 */ \ + r_gen(), false, \ + IS_POINTER), /* d1 == 0 && d2 < 0 */ \ + std::make_tuple( \ + C, 0.0, -3.4, r_gen(), r_gen(), false, \ + IS_POINTER), /* d1 * x1 > d2 * y1 (i.e. abs_c > abs_s) */ \ + std::make_tuple(C, 4.0, 2.1, 3.4, 1.5, false, IS_POINTER), \ + std::make_tuple(C, 4.0, 1.5, -3.4, 2.1, false, IS_POINTER), \ + std::make_tuple(C, 4.0, -1.5, 3.4, 2.1, false, IS_POINTER), \ + std::make_tuple(C, 4.0, -1.5, 3.4, -2.1, false, IS_POINTER), \ std::make_tuple( \ - C, 0.0, -3.4, r_gen(), r_gen(), \ - false), /* d1 * x1 > d2 * y1 (i.e. abs_c > abs_s) */ \ - std::make_tuple(C, 4.0, 2.1, 3.4, 1.5, false), \ - std::make_tuple(C, 4.0, 1.5, -3.4, 2.1, false), \ - std::make_tuple(C, 4.0, -1.5, 3.4, 2.1, false), \ - std::make_tuple(C, 4.0, -1.5, 3.4, -2.1, false), \ - std::make_tuple(C, 4.0, -1.5, -3.4, -2.1, \ - false), /* d1 * x1 > d2 * y1 (i.e. abs_c > abs_s) \ - && rescaling */ \ + C, 4.0, -1.5, -3.4, -2.1, false, \ + IS_POINTER), /* d1 * x1 > d2 * y1 (i.e. abs_c > abs_s) \ + && rescaling */ \ std::make_tuple(C, scale_down_gen(), 2.1, 3.4, 1.5, \ - false), \ + false, IS_POINTER), \ std::make_tuple(C, scale_down_gen(), 2.1, \ - scale_down_gen(), 1.5, false), \ + scale_down_gen(), 1.5, false, \ + IS_POINTER), \ std::make_tuple(C, scale_up_gen(), 2.1, \ - scale_down_gen(), 1.5, false), \ - std::make_tuple(C, scale_down_gen(), 2.1, \ - scale_up_gen(), 1.5, \ - false), /* d1 * x1 > d2 * y1 (i.e. abs_c > abs_s) \ - && Underflow */ \ + scale_down_gen(), 1.5, false, \ + IS_POINTER), \ + std::make_tuple( \ + C, scale_down_gen(), 2.1, scale_up_gen(), \ + 1.5, false, \ + IS_POINTER), /* d1 * x1 > d2 * y1 (i.e. abs_c > abs_s) \ + && Underflow */ \ std::make_tuple(C, 0.01, 0.01, \ std::numeric_limits::min(), \ - std::numeric_limits::min(), \ - true), /* d1 * x1 > d2 * y1 && Overflow */ \ + std::numeric_limits::min(), true, \ + IS_POINTER), /* d1 * x1 > d2 * y1 && Overflow */ \ std::make_tuple( \ C, std::numeric_limits::max(), \ - std::numeric_limits::max(), 0.01, 0.01, \ - true), /* d1 * x1 <= d2 * y1 (i.e. abs_c <= abs_s) */ \ - std::make_tuple(C, 2.1, 4.0, 1.5, 3.4, false), \ - std::make_tuple(C, 2.1, 4.0, -1.5, 3.4, false), \ - std::make_tuple(C, 2.1, -4.0, 1.5, 3.4, false), \ - std::make_tuple(C, 2.1, -4.0, 1.5, -3.4, false), \ - std::make_tuple(C, 2.1, -4.0, -1.5, -3.4, \ - false), /* d1 * x1 <= d2 * y1 (i.e. abs_c <= \ - abs_s) && rescaling */ \ + std::numeric_limits::max(), 0.01, 0.01, true, \ + IS_POINTER), /* d1 * x1 <= d2 * y1 (i.e. abs_c <= abs_s) */ \ + std::make_tuple(C, 2.1, 4.0, 1.5, 3.4, false, IS_POINTER), \ + std::make_tuple(C, 2.1, 4.0, -1.5, 3.4, false, IS_POINTER), \ + std::make_tuple(C, 2.1, -4.0, 1.5, 3.4, false, IS_POINTER), \ + std::make_tuple(C, 2.1, -4.0, 1.5, -3.4, false, IS_POINTER), \ + std::make_tuple(C, 2.1, -4.0, -1.5, -3.4, false, \ + IS_POINTER), /* d1 * x1 <= d2 * y1 (i.e. abs_c <= \ + abs_s) && rescaling */ \ std::make_tuple(C, 2.1, scale_down_gen(), 1.5, 3.4, \ - false), \ + false, IS_POINTER), \ std::make_tuple(C, 2.1, scale_down_gen(), 1.5, \ - scale_down_gen(), false), \ + scale_down_gen(), false, IS_POINTER), \ std::make_tuple(C, 2.1, scale_up_gen(), 1.5, \ - scale_down_gen(), false), \ + scale_down_gen(), false, IS_POINTER), \ std::make_tuple(C, 2.1, scale_down_gen(), 1.5, \ - scale_up_gen(), \ - false), /* d1 * x1 <= d2 * y1 (i.e. abs_c <= \ - abs_s) && Underflow */ \ + scale_up_gen(), false, \ + IS_POINTER), /* d1 * x1 <= d2 * y1 (i.e. abs_c <= \ + abs_s) && Underflow */ \ std::make_tuple(C, std::numeric_limits::min(), \ std::numeric_limits::min(), 0.01, 0.01, \ - true), /* d1 * x1 <= d2 * y1 (i.e. abs_c <= \ - abs_s) && Overflow */ \ + true, IS_POINTER), /* d1 * x1 <= d2 * y1 (i.e. \ + abs_c <= abs_s) && Overflow */ \ std::make_tuple(C, 0.01, 0.01, \ std::numeric_limits::max(), \ - std::numeric_limits::max(), \ - true), /* Overflow all */ \ + std::numeric_limits::max(), true, \ + IS_POINTER), /* Overflow all */ \ std::make_tuple(C, std::numeric_limits::max(), \ std::numeric_limits::max(), \ std::numeric_limits::max(), \ - std::numeric_limits::max(), \ - true), /* Underflow all */ \ - std::make_tuple(C, std::numeric_limits::min(), \ - std::numeric_limits::min(), \ - std::numeric_limits::min(), \ - std::numeric_limits::min(), \ - true), /* Numeric limits of one parameter */ \ + std::numeric_limits::max(), true, \ + IS_POINTER), /* Underflow all */ \ + std::make_tuple( \ + C, std::numeric_limits::min(), \ + std::numeric_limits::min(), \ + std::numeric_limits::min(), \ + std::numeric_limits::min(), true, \ + IS_POINTER), /* Numeric limits of one parameter */ \ std::make_tuple(C, 1.0, 1.0, 1.0, \ - std::numeric_limits::max(), false), \ + std::numeric_limits::max(), false, \ + IS_POINTER), \ std::make_tuple(C, 1.0, 1.0, \ - std::numeric_limits::max(), 1.0, \ - false), \ + std::numeric_limits::max(), 1.0, false, \ + IS_POINTER), \ std::make_tuple(C, 1.0, std::numeric_limits::max(), \ - 1.0, 1.0, false), \ - std::make_tuple( \ - C, std::numeric_limits::max(), 1.0, 1.0, 1.0, \ - false), /* Case that creates an infinite loop on cblas */ \ + 1.0, 1.0, false, IS_POINTER), \ + std::make_tuple(C, std::numeric_limits::max(), 1.0, \ + 1.0, 1.0, false, \ + IS_POINTER), /* Case that creates an infinite \ + loop on cblas */ \ std::make_tuple(C, std::numeric_limits::min(), -2.2, \ std::numeric_limits::min(), \ - std::numeric_limits::min(), \ - true), /* Case that triggers underflow detection \ - on abs_c <= abs_s && s >= 0 */ \ + std::numeric_limits::min(), true, \ + IS_POINTER), /* Case that triggers underflow \ + detection on abs_c <= abs_s && s >= 0 */ \ std::make_tuple(C, 15.5, -2.2, \ std::numeric_limits::min(), \ - std::numeric_limits::min(), \ - false), /* Test for previous errors */ \ + std::numeric_limits::min(), false, \ + IS_POINTER), /* Test for previous errors */ \ std::make_tuple(C, 0.0516274, -0.197215, -0.270436, -0.157621, \ - false)) + false, IS_POINTER)) #ifdef SB_ENABLE_USM -INSTANTIATE_ROTMG_TESTS(combi_usm, "usm"); // instantiate usm tests +INSTANTIATE_ROTMG_TESTS(combi_usm, "usm", true); // instantiate usm tests +INSTANTIATE_ROTMG_TESTS(combi_usm_scalar, "usm", + false); // instantiate usm tests #endif -INSTANTIATE_ROTMG_TESTS(combi_buffer, "buf"); // instantiate buffer tests +INSTANTIATE_ROTMG_TESTS(combi_buffer, "buf", true); // instantiate buffer tests +INSTANTIATE_ROTMG_TESTS(combi_buffer_scalar, "buf", + false); // instantiate buffer tests template static std::string generate_name( const ::testing::TestParamInfo>& info) { std::string alloc; T d1, d2, x1, y1; - bool will_overflow; - BLAS_GENERATE_NAME(info.param, alloc, d1, d2, x1, y1, will_overflow); + bool will_overflow, is_pointer; + BLAS_GENERATE_NAME(info.param, alloc, d1, d2, x1, y1, will_overflow, + is_pointer); } #ifdef SB_ENABLE_USM BLAS_REGISTER_TEST_ALL(Rotmg_Usm, combination_t, combi_usm, generate_name); +BLAS_REGISTER_TEST_ALL(Rotmg_Usm_scalar, combination_t, combi_usm_scalar, + generate_name); #endif BLAS_REGISTER_TEST_ALL(Rotmg_Buffer, combination_t, combi_buffer, generate_name); +BLAS_REGISTER_TEST_ALL(Rotmg_Buffer_scalar, combination_t, combi_buffer_scalar, + generate_name); #undef INSTANTIATE_ROTMG_TESTS