From 19b0fedad0da040498d350200cbafeedb09d94ee Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI <104583441+OuadiElfarouki@users.noreply.github.com> Date: Tue, 24 Oct 2023 10:30:20 +0100 Subject: [PATCH] Enabed Complex data type for Gemm (#462) Added preliminary support for sycl::complex data types for GEMM operator along with the relevant unit tests. --- CMakeLists.txt | 2 + README.md | 2 +- cmake/CmakeFunctionHelper.cmake | 131 ++++++- common/include/common/float_comparison.hpp | 98 ++++- .../include/common/system_reference_blas.hpp | 15 + doc/Gemm.md | 2 +- include/blas_meta.h | 26 +- include/operations/blas_constants.h | 8 +- src/interface/blas3/backend/amd_gpu.hpp | 80 ++++- src/interface/blas3/backend/default_cpu.hpp | 55 ++- src/interface/blas3/backend/intel_gpu.hpp | 81 ++++- src/interface/blas3/backend/nvidia_gpu.hpp | 42 ++- src/interface/gemm_interface.hpp | 33 +- src/operations/blas1_trees.hpp | 19 +- src/operations/blas3/gemm_common.hpp | 19 +- src/operations/blas3/gemm_interleaved.hpp | 20 +- .../blas3/gemm_load_store_complex.hpp | 174 +++++++++ src/operations/blas3/gemm_local.hpp | 31 +- .../blas3/gemm_no_local_full_vec.hpp | 54 ++- .../blas3/gemm_no_local_partial_vec.hpp | 42 ++- src/operations/blas3/gemm_partial_local.hpp | 4 +- test/blas_test.hpp | 48 ++- test/blas_test_macros.hpp | 42 +++ test/unittest/CMakeLists.txt | 5 + .../blas3/blas3_gemm_batched_test.cpp | 64 ++++ test/unittest/blas3/blas3_gemm_common.hpp | 338 +++++++++++++++++- .../blas3/blas3_gemm_tall_skinny_test.cpp | 78 ++++ test/unittest/blas3/blas3_gemm_test.cpp | 118 ++++++ 28 files changed, 1514 insertions(+), 117 deletions(-) create mode 100644 src/operations/blas3/gemm_load_store_complex.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index a6b85f570..09785078f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,6 +106,7 @@ if(IMGDNN_DIR) endif() option(BLAS_ENABLE_EXTENSIONS "Whether to enable portBLAS extensions" ON) +option(BLAS_ENABLE_COMPLEX "Whether to enable complex data type for supported operators" ON) # CmakeFunctionHelper has to be included after any options that it depends on are declared. # These include: @@ -115,6 +116,7 @@ option(BLAS_ENABLE_EXTENSIONS "Whether to enable portBLAS extensions" ON) # * BLAS_DATA_TYPES # * BLAS_INDEX_TYPES # * NAIVE_GEMM +# * BLAS_ENABLE_COMPLEX include(CmakeFunctionHelper) if (INSTALL_HEADER_ONLY) diff --git a/README.md b/README.md index 5720ae145..c5383b73f 100644 --- a/README.md +++ b/README.md @@ -463,7 +463,7 @@ Some of the supported options are: | `BLAS_ENABLE_EXTENSIONS` | `ON`/`OFF` | Determines whether to enable portBLAS extensions (`ON` by default) | | `BLAS_DATA_TYPES` | `half;float;double` | Determines the floating-point types to instantiate BLAS operations for. Default is `float` | | `BLAS_INDEX_TYPES` | `int32_t;int64_t` | Determines the type(s) to use for `index_t` and `increment_t`. Default is `int` | - +| `BLAS_ENABLE_COMPLEX` | `ON`/`OFF` | Determines whether to enable Complex data type support *(GEMM Kernels only)* (`ON` by default) | ### Cross-Compile (ComputeCpp Only) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index 1be6ab8a8..2ae71bc5e 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -36,10 +36,30 @@ function(cpp_type output data) if (${data} STREQUAL "half") set(${output} "cl::sycl::half" PARENT_SCOPE) return() + elseif(${data} STREQUAL "complex") + set(${output} "cl::sycl::ext::oneapi::experimental::complex" PARENT_SCOPE) + return() + elseif(${data} STREQUAL "complex") + set(${output} "cl::sycl::ext::oneapi::experimental::complex" PARENT_SCOPE) + return() endif() set(${output} "${data}" PARENT_SCOPE) endfunction() +function(set_complex_list output input append) + set(output_temp "") + if(${append} STREQUAL "true") + foreach(data ${input}) + list(APPEND output_temp "${data};complex<${data}>") + endforeach(data) + else() + foreach(data ${input}) + list(APPEND output_temp "complex<${data}>") + endforeach(data) + endif() + set(${output} ${output_temp} PARENT_SCOPE) +endfunction(set_complex_list) + ## represent the list of bolean options set(boolean_list "true" "false") @@ -56,6 +76,9 @@ function(sanitize_file_name output file_name) set(${output} "${file_name}" PARENT_SCOPE) endfunction() +#List of operators supporting Complex Data types +set(COMPLEX_OPS "gemm" "gemm_launcher" "scal") + function(set_target_compile_def in_target) #setting compiler flag for backend if(${TUNING_TARGET} STREQUAL "INTEL_GPU") @@ -84,16 +107,31 @@ function(set_target_compile_def in_target) message(STATUS "Gemm vectorization support enabled for target ${in_target}") target_compile_definitions(${in_target} PUBLIC GEMM_VECTORIZATION_SUPPORT=1) endif() - + #setting const data type support if(BLAS_ENABLE_CONST_INPUT) target_compile_definitions(${in_target} PUBLIC BLAS_ENABLE_CONST_INPUT=1) endif() + #setting complex support + if(${BLAS_ENABLE_COMPLEX}) + if("${in_target}" IN_LIST COMPLEX_OPS) + message(STATUS "Complex Data type support enabled for target ${in_target}") + target_compile_definitions(${in_target} PUBLIC BLAS_ENABLE_COMPLEX=1) + endif() + endif() endfunction() # blas unary function for generating source code function(generate_blas_objects blas_level func) set(LOCATION "${PORTBLAS_GENERATED_SRC}/${blas_level}/${func}/") - foreach(data ${data_list}) + set(data_list_c ${data_list}) + # Extend data_list to complex for each data in list + # if target function is in COMPLEX_OPS + if(BLAS_ENABLE_COMPLEX) + if("${func}" IN_LIST COMPLEX_OPS) + set_complex_list(data_list_c "${data_list}" "true") + endif() + endif() + foreach(data ${data_list_c}) cpp_type(cpp_data ${data}) foreach(index ${index_list}) foreach(increment ${index_list}) @@ -234,7 +272,11 @@ function(add_gemm_configuration batch_type use_joint_matrix ) - if(NOT ("${data}" IN_LIST data_list)) + set(data_list_c ${data_list}) + if(BLAS_ENABLE_COMPLEX) + set_complex_list(data_list_c "${data_list}" "true") + endif() + if(NOT ("${data}" IN_LIST data_list_c)) # Data type not enabled, skip configuration return() endif() @@ -249,6 +291,9 @@ function(add_gemm_configuration cpp_type(cpp_data ${data}) foreach(symm_a ${boolean_list}) foreach(symm_b ${boolean_list}) + if ((${data} MATCHES "complex") AND (symm_a OR symm_b)) + continue() + endif() foreach(trans_a ${boolean_list}) foreach(trans_b ${boolean_list}) foreach(is_beta_zero ${boolean_list}) @@ -380,6 +425,32 @@ if(${TUNING_TARGET} STREQUAL "INTEL_GPU") "${data}" 64 "false" "false" "false" 64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false") endforeach() + if(BLAS_ENABLE_COMPLEX) + # Extract list of complex for each data in supported_types + # list for complex specific gemm configurations + set(data_list_c) + set_complex_list(data_list_c "${supported_types}" "false") + foreach(data ${data_list_c}) + add_gemm_configuration( + "${data}" 64 "true" "false" "false" + 64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 64 "false" "false" "false" + 64 4 8 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 64 "false" "false" "false" + 64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false") + if (${data} STREQUAL "complex") + add_gemm_configuration( + "${data}" 64 "true" "true" "true" + 64 4 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + else() + add_gemm_configuration( + "${data}" 64 "true" "true" "true" + 64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + endif() + endforeach() + endif() # BLAS_ENABLE_COMPLEX elseif(${TUNING_TARGET} STREQUAL "POWER_VR" AND NOT IMGDNN_DIR) set(supported_types "float" @@ -445,6 +516,35 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation "${data}" 64 "false" "false" "false" 64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false") endforeach() + if(BLAS_ENABLE_COMPLEX) + # Extract list of complex for each data in supported_types + # list for complex specific gemm configurations + set(data_list_c) + set_complex_list(data_list_c "${supported_types}" "false") + foreach(data ${data_list_c}) + if (${data} STREQUAL "complex") + add_gemm_configuration( + "${data}" 256 "true" "true" "true" + 64 1 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 1 1 4 4 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 4 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + else() + add_gemm_configuration( + "${data}" 256 "true" "true" "true" + 64 1 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 1 1 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + endif() + endforeach() + endif() # BLAS_ENABLE_COMPLEX elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU") set(supported_types "float" @@ -486,7 +586,18 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU") add_gemm_configuration( "${data}" 256 "false" "true" "true" 128 8 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + endforeach() + if(BLAS_ENABLE_COMPLEX) + # Extract list of complex for each data in supported_types + # list for complex specific gemm configurations + set(data_list_c) + set_complex_list(data_list_c "${supported_types}" "false") + foreach(data ${data_list_c}) + add_gemm_configuration( + "${data}" 256 "false" "false" "true" + 64 2 2 16 16 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") endforeach() + endif() # BLAS_ENABLE_COMPLEX else() # default cpu backend set(supported_types "float" @@ -513,6 +624,20 @@ else() # default cpu backend "${data}" 64 "false" "false" "false" 64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false" "false") endforeach() + if(BLAS_ENABLE_COMPLEX) + # Extract list of complex for each data in supported_types + # list for complex specific gemm configurations + set(data_list_c) + set_complex_list(data_list_c "${supported_types}" "false") + foreach(data ${data_list_c}) + add_gemm_configuration( + "${data}" 64 "false" "false" "false" + 64 2 2 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false") + add_gemm_configuration( + "${data}" 64 "false" "false" "false" + 64 8 8 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false") + endforeach() + endif() # BLAS_ENABLE_COMPLEX endif() add_library(${func} OBJECT ${gemm_sources}) set_target_compile_def(${func}) diff --git a/common/include/common/float_comparison.hpp b/common/include/common/float_comparison.hpp index 43f8f578b..1222ccc41 100644 --- a/common/include/common/float_comparison.hpp +++ b/common/include/common/float_comparison.hpp @@ -28,6 +28,9 @@ #include #include +#ifdef BLAS_ENABLE_COMPLEX +#include +#endif #ifdef BLAS_DATA_TYPE_HALF #if SYCL_LANGUAGE_VERSION < 202000 @@ -65,6 +68,23 @@ scalar_t abs(scalar_t value) noexcept { return std::abs(value); } +#ifdef BLAS_ENABLE_COMPLEX +template +bool isnan(std::complex value) noexcept { + return (isnan(value.real()) || isnan(value.imag())); +} + +template +bool isinf(std::complex value) noexcept { + return (isinf(value.real()) || isinf(value.imag())); +} + +template +scalar_t abs(std::complex value) noexcept { + return std::abs(value); +} +#endif + #ifdef BLAS_DATA_TYPE_HALF template <> inline bool isnan(cl::sycl::half value) noexcept { @@ -172,7 +192,7 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) { return true; } - const scalar_t absolute_diff = utils::abs(scalar1 - scalar2); + const auto absolute_diff = utils::abs(scalar1 - scalar2); // Close to zero, the relative error doesn't work, use absolute error if (scalar1 == scalar_t{0} || scalar2 == scalar_t{0} || @@ -212,6 +232,37 @@ inline bool compare_vectors(std::vector const& vec, return true; } +#ifdef BLAS_ENABLE_COMPLEX +/** + * Compare two vectors of complex data and returns false if the difference is + * not acceptable. The second vector is considered the reference. + * @tparam scalar_t the type of complex underying data present in the input + * vectors + * @tparam epilon_t the type used as tolerance. + */ +template +inline bool compare_vectors(std::vector> const& vec, + std::vector> const& ref, + std::ostream& err_stream = std::cerr, + std::string end_line = "\n") { + if (vec.size() != ref.size()) { + err_stream << "Error: tried to compare vectors of different sizes" + << std::endl; + return false; + } + + for (int i = 0; i < vec.size(); ++i) { + if (!almost_equal, epsilon_t>(vec[i], ref[i])) { + err_stream << "Value mismatch at index " << i << ": (" << vec[i].real() + << "," << vec[i].imag() << "); expected (" << ref[i].real() + << "," << ref[i].imag() << ")" << end_line; + return false; + } + } + return true; +} +#endif + /** * Compare two vectors at a given stride and window (unit_vec_size) and returns * false if the difference is not acceptable. The second vector is considered @@ -253,6 +304,51 @@ inline bool compare_vectors_strided(std::vector const& vec, return true; } +#ifdef BLAS_ENABLE_COMPLEX +/** + * Compare two vectors of complex data at a given stride and window and returns + * false if the difference is not acceptable. The second vector is considered + * the reference. + * @tparam scalar_t the type of the complex underying data present in the input + * vectors + * @tparam epsilon_t the type used as tolerance. + * @param stride is the stride between two consecutive 'windows' + * @param window is the size of a comparison window + */ +template +inline bool compare_vectors_strided( + std::vector> const& vec, + std::vector> const& ref, int stride, int window, + std::ostream& err_stream = std::cerr, std::string end_line = "\n") { + if (vec.size() != ref.size()) { + err_stream << "Error: tried to compare vectors of different sizes" + << std::endl; + return false; + } + + int k = 0; + + // Loop over windows + while (window + (k + 1) * stride < vec.size()) { + // Loop within a window + for (int i = 0; i < window; ++i) { + auto index = i + k * stride; + if (!almost_equal, epsilon_t>(vec[index], + ref[index])) { + err_stream << "Value mismatch at index " << index << ": (" + << vec[index].real() << "," << vec[index].imag() + << "); expected (" << ref[index].real() << "," + << ref[index].imag() << ")" << end_line; + return false; + } + } + k += 1; + } + + return true; +} +#endif + } // namespace utils #endif // UTILS_FLOAT_COMPARISON_H_ diff --git a/common/include/common/system_reference_blas.hpp b/common/include/common/system_reference_blas.hpp index afcb4f5e4..cd07e27cf 100644 --- a/common/include/common/system_reference_blas.hpp +++ b/common/include/common/system_reference_blas.hpp @@ -133,6 +133,12 @@ auto blas_system_function(floatfn_t ffn, doublefn_t dfn) return BlasSystemFunction::get(ffn, dfn); } +template +auto blas_cplx_system_function(floatfn_t ffn, doublefn_t dfn) + -> decltype(BlasSystemFunction::get(ffn, dfn)) { + return BlasSystemFunction::get(ffn, dfn); +} + // ======= // Level 1 // ======= @@ -378,6 +384,15 @@ void gemm(const char *transA, const char *transB, int m, int n, int k, lda, b, ldb, beta, c, ldc); } +template +void cgemm(const char *transA, const char *transB, int m, int n, int k, + const void *alpha, const void *a, int lda, const void *b, int ldb, + const void *beta, void *c, int ldc) { + auto func = blas_cplx_system_function(&cblas_cgemm, &cblas_zgemm); + func(CblasColMajor, c_trans(*transA), c_trans(*transB), m, n, k, alpha, a, + lda, b, ldb, beta, c, ldc); +} + template void trsm(const char *side, const char *uplo, const char *trans, const char *diag, int m, int n, scalar_t alpha, const scalar_t A[], diff --git a/doc/Gemm.md b/doc/Gemm.md index 0264e3d4c..653549212 100644 --- a/doc/Gemm.md +++ b/doc/Gemm.md @@ -100,7 +100,7 @@ The core of the `GEMM` computation is as follows: ## Vectorized Loading/Storing -Many of the `GEMM` kernels support vectorized loads/stores using functions located in `gemm_load_store.hpp` in `src/operations/blas3/` . +Many of the `GEMM` kernels support vectorized loads/stores using functions located in `gemm_load_store.hpp` in `src/operations/blas3/`*(this feature is limited to non-complex data types)*. These functions are pretty simple but there are some special considerations for how they are used, particularly around whether the matrices are transposed or not. If a matrix is transposed this changes the data layout such that elements are no longer contiguous in memory. diff --git a/include/blas_meta.h b/include/blas_meta.h index 6bad4be98..d39a395f5 100644 --- a/include/blas_meta.h +++ b/include/blas_meta.h @@ -29,6 +29,11 @@ #include #include #include +#ifdef BLAS_ENABLE_COMPLEX +#define SYCL_EXT_ONEAPI_COMPLEX +#include +#include +#endif namespace blas { @@ -162,7 +167,7 @@ int append_vector(vector_t &lhs_vector, vector_t const &rhs_vector) { template first_vector_t concatenate_vectors(first_vector_t first_vector, - other_vector_t &&... other_vectors) { + other_vector_t &&...other_vectors) { int first_Vector_size = static_cast(first_vector.size()); int s[] = {vec_total_size(first_Vector_size, other_vectors)..., 0}; first_vector.reserve(first_Vector_size); @@ -190,6 +195,25 @@ struct is_sycl_scalar : std::false_type {}; template <> struct is_sycl_scalar : std::false_type {}; +#ifdef BLAS_ENABLE_COMPLEX +// SYCL Complex type alias +template +using complex_sycl = typename cl::sycl::ext::oneapi::experimental::complex; + +template +struct is_complex_sycl + : std::integral_constant> || + std::is_same_v>> {}; + +template +struct is_complex_std + : std::integral_constant> || + std::is_same_v>> {}; + +#endif + } // namespace blas #endif // BLAS_META_H diff --git a/include/operations/blas_constants.h b/include/operations/blas_constants.h index 103c78152..637f23f95 100644 --- a/include/operations/blas_constants.h +++ b/include/operations/blas_constants.h @@ -202,13 +202,15 @@ struct constant, const_val::collapse> { } }; +#ifdef BLAS_ENABLE_COMPLEX template -struct constant, Indicator> { - constexpr static PORTBLAS_INLINE std::complex value() { - return std::complex(constant::value(), +struct constant, Indicator> { + constexpr static PORTBLAS_INLINE complex_sycl value() { + return complex_sycl(constant::value(), constant::value()); } }; +#endif #ifdef BLAS_DATA_TYPE_HALF template <> diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index be864ae76..f494f25b9 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -33,13 +33,14 @@ namespace backend { template -typename sb_handle_t::event_t _gemm( - sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, - container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, - gemm_batch_type_t batch_type, - const typename sb_handle_t::event_t& _dependencies) { +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { static constexpr int ClSize = 64; static constexpr int tileWgSize = ClSize / sizeof(element_t); if (batch_type == gemm_batch_type_t::interleaved) { @@ -118,10 +119,65 @@ typename sb_handle_t::event_t _gemm( } } else #endif // GEMM_TALL_SKINNY_SUPPORT - if (_M * _N <= 65536) { + if (_M * _N <= 65536) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, false, + ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, false, + ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 2, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } +} + +// Complex Configurations +#ifdef BLAS_ENABLE_COMPLEX +template +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + static constexpr int ClSize = 64; + static constexpr int tileWgSize = ClSize / sizeof(element_t); +/* Tall & Skinny matrices. */ +#ifdef GEMM_TALL_SKINNY_SUPPORT + if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8)) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, true, true, true, + ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, false, false, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } +#endif + if (_M * _N <= 65536) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, @@ -132,16 +188,18 @@ typename sb_handle_t::event_t _gemm( } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 2, + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, static_cast(gemm_batch_type_t::strided)>:: template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); } } +#endif + } // namespace backend } // namespace gemm } // namespace blas diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp index 17868991e..e62348363 100644 --- a/src/interface/blas3/backend/default_cpu.hpp +++ b/src/interface/blas3/backend/default_cpu.hpp @@ -33,13 +33,14 @@ namespace backend { template -typename sb_handle_t::event_t _gemm( - sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, - container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, - gemm_batch_type_t batch_type, - const typename sb_handle_t::event_t& _dependencies) { +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { if (batch_type == gemm_batch_type_t::interleaved) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, @@ -101,6 +102,46 @@ typename sb_handle_t::event_t _gemm( #endif } + +// Complex Configurations +#ifdef BLAS_ENABLE_COMPLEX +template +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + if (_M <= 256 && _N <= 256 && _K <= 256) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<2, 2, 4, 4>, _t_a, _t_b, false, false, + static_cast(gemm_memory_t::no_local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<8, 8, 4, 4>, _t_a, _t_b, false, false, + static_cast(gemm_memory_t::no_local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } +} +#endif + } // namespace backend } // namespace gemm } // namespace blas diff --git a/src/interface/blas3/backend/intel_gpu.hpp b/src/interface/blas3/backend/intel_gpu.hpp index 8fcb3e3a8..8d788c9b5 100644 --- a/src/interface/blas3/backend/intel_gpu.hpp +++ b/src/interface/blas3/backend/intel_gpu.hpp @@ -32,13 +32,14 @@ namespace backend { template -typename sb_handle_t::event_t _gemm( - sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, - container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, - gemm_batch_type_t batch_type, - const typename sb_handle_t::event_t& _dependencies) { +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { if (batch_type == gemm_batch_type_t::interleaved) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, @@ -206,6 +207,72 @@ typename sb_handle_t::event_t _gemm( batch_size, _dependencies); } } + +// Complex Configurations +#ifdef BLAS_ENABLE_COMPLEX +template +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { +#ifdef GEMM_TALL_SKINNY_SUPPORT + if (batch_size == 1) { + constexpr int wg_size = sizeof(element_t) == 16 ? 4 : 8; + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, true, true, true, 64, + Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, false, false, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } +#endif + if (_M <= 128 && _N <= 128) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, true, false, false, 64, + Tile<4, 4, 8, 8>, _t_a, _t_b, false, false, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } else if (_t_b && !_t_a) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<8, 8, 8, 8>, _t_a, _t_b, false, false, + static_cast(gemm_memory_t::no_local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<4, 8, 16, 8>, _t_a, _t_b, false, false, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } +} +#endif + } // namespace backend } // namespace gemm } // namespace blas diff --git a/src/interface/blas3/backend/nvidia_gpu.hpp b/src/interface/blas3/backend/nvidia_gpu.hpp index aeb678704..13966172e 100644 --- a/src/interface/blas3/backend/nvidia_gpu.hpp +++ b/src/interface/blas3/backend/nvidia_gpu.hpp @@ -33,13 +33,14 @@ namespace backend { template -typename sb_handle_t::event_t _gemm( - sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, - container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, - gemm_batch_type_t batch_type, - const typename sb_handle_t::event_t& _dependencies) { +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { if (batch_type == gemm_batch_type_t::interleaved) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, @@ -167,6 +168,33 @@ typename sb_handle_t::event_t _gemm( _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); } + +// Complex Configurations +#ifdef BLAS_ENABLE_COMPLEX +template +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, true, 64, + Tile<2, 2, 16, 16, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b, + false, false, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); +} +#endif + } // namespace backend } // namespace gemm } // namespace blas diff --git a/src/interface/gemm_interface.hpp b/src/interface/gemm_interface.hpp index a5c2c7bb3..8e90a4b82 100644 --- a/src/interface/gemm_interface.hpp +++ b/src/interface/gemm_interface.hpp @@ -48,6 +48,22 @@ namespace blas { */ namespace internal { +// Check whether value is zero (complex & float/double) +template +inline typename std::enable_if::value, bool>::type isZero( + const T& value) { + return (value == static_cast(0)); +} + +#ifdef BLAS_ENABLE_COMPLEX +template +inline typename std::enable_if::value, bool>::type isZero( + const T& value) { + using value_t = typename T::value_type; + return (value == T(value_t(0), value_t(0))); +} +#endif + template @@ -73,15 +89,14 @@ typename sb_handle_t::event_t _gemm_is_beta_zero( container_2_t _C, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { - return ((_beta == static_cast(0)) - ? _gemm_platform_specific<_t_a, _t_b, s_a, s_b, true>( - sb_handle, _M, _N, _K, _alpha, a_, _lda, _stridea, b_, _ldb, - _strideb, _beta, _C, _ldc, _stridec, batch_size, batch_type, - _dependencies) - : _gemm_platform_specific<_t_a, _t_b, s_a, s_b, false>( - sb_handle, _M, _N, _K, _alpha, a_, _lda, _stridea, b_, _ldb, - _strideb, _beta, _C, _ldc, _stridec, batch_size, batch_type, - _dependencies)); + return isZero(_beta) ? _gemm_platform_specific<_t_a, _t_b, s_a, s_b, true>( + sb_handle, _M, _N, _K, _alpha, a_, _lda, _stridea, + b_, _ldb, _strideb, _beta, _C, _ldc, _stridec, + batch_size, batch_type, _dependencies) + : _gemm_platform_specific<_t_a, _t_b, s_a, s_b, false>( + sb_handle, _M, _N, _K, _alpha, a_, _lda, _stridea, + b_, _ldb, _strideb, _beta, _C, _ldc, _stridec, + batch_size, batch_type, _dependencies); } template { }; #endif // BLAS_DATA_TYPE_HALF -/*! DetectScalar. - * @brief See Detect Scalar. - */ -template <> -struct DetectScalar> { - using element_t = std::complex; - static element_t get_scalar(element_t &scalar) { return scalar; } -}; - -/*! DetectScalar. +#ifdef BLAS_ENABLE_COMPLEX +/*! DetectScalar (for sycl::complex) * @brief See Detect Scalar. */ -template <> -struct DetectScalar> { - using element_t = std::complex; +template +struct DetectScalar> { + using element_t = complex_sycl; static element_t get_scalar(element_t &scalar) { return scalar; } }; +#endif /*! get_scalar. * @brief Template autodecuction function for DetectScalar. diff --git a/src/operations/blas3/gemm_common.hpp b/src/operations/blas3/gemm_common.hpp index 4966b9f13..670dc340d 100644 --- a/src/operations/blas3/gemm_common.hpp +++ b/src/operations/blas3/gemm_common.hpp @@ -33,6 +33,22 @@ namespace blas { +#ifdef BLAS_ENABLE_COMPLEX +template +static PORTBLAS_INLINE T +mul_add(T a, T b, T c, + typename std::enable_if::value>::type * = 0) { + return (a * b + c); +} +#endif + +template +static PORTBLAS_INLINE T +mul_add(T a, T b, T c, + typename std::enable_if::value>::type * = 0) { + return (cl::sycl::mad(a, b, c)); +} + template struct type_string { static const char *get_value() { return "unknown"; } @@ -62,7 +78,8 @@ template PORTBLAS_INLINE std::string Tile::get_type_string() noexcept { + ItemBatchs, WgBatchs, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>::get_type_string() noexcept { std::ostringstream str{}; str << "Tile<" << item_rows << ", " << item_cols << ", " << wg_rows << ", " << wg_cols << ", " << sg_rows << ", " << sg_cols << ", " << tl_rows diff --git a/src/operations/blas3/gemm_interleaved.hpp b/src/operations/blas3/gemm_interleaved.hpp index 551bb465a..66629033e 100644 --- a/src/operations/blas3/gemm_interleaved.hpp +++ b/src/operations/blas3/gemm_interleaved.hpp @@ -146,6 +146,11 @@ class Gemm::value, + "Interleaved GEMM is not supported for Complex Data types"); +#endif + input_t a_; input_t b_; output_t c_; @@ -159,10 +164,9 @@ class Gemm PORTBLAS_INLINE void compute_panel(check_t boundary_check, index_t m_stride, - index_t n_stride, index_t mb_start, - index_t m_start, index_t n_start, - in_ptr_t A, in_ptr_t B, out_ptr_t C) { + index_t n_stride, index_t mb_start, + index_t m_start, index_t n_start, + in_ptr_t A, in_ptr_t B, out_ptr_t C) { packet_type reg_a[item_rows * item_batchs / VectorSize]; packet_type reg_b[item_cols * item_batchs / VectorSize]; packet_type reg_res[item_rows * item_cols * item_batchs / VectorSize]; @@ -482,7 +486,7 @@ class Gemm container + * github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_complex.asciidoc + * and only supports size = 1. + * @tparam DataT Complex type of the vector's data + * @tparam NumElements Elements count of the vector (only 1 is supported) + */ +template +class vec_complex { + static_assert(NumElements == 1, + "Vector wrapper arround sycl::complex of size>1 unsupported."); + using address_t = cl::sycl::access::address_space; + using decorated_t = cl::sycl::access::decorated; + using DataType = DataT; + static constexpr int getNumElements() { return NumElements; } + size_t size() const noexcept { return NumElements; } + + private: + DataType m_Data; + + public: + vec_complex() = default; + + constexpr vec_complex(const vec_complex &rhs) = default; + constexpr vec_complex(vec_complex &&rhs) = default; + constexpr vec_complex &operator=(const vec_complex &rhs) = default; + + vec_complex(const DataType &rhs_data) : m_Data{rhs_data} {} + + // Conversion operator (valid with NumElements==1) + operator DataT() const { return m_Data; } + + // Subscript operators + DataT &operator[](int i) { + assert(i < NumElements); + return (m_Data); + } + const DataT &operator[](int i) const { + assert(i < NumElements); + return (m_Data); + } + + // Binary Ops + // Multiply + vec_complex operator*(const vec_complex &rhs) { + return (vec_complex{m_Data * static_cast(rhs)}); + } + + vec_complex operator*(const DataType &rhs) { + return (vec_complex{m_Data * rhs}); + } + + // Compound Multiply + vec_complex &operator*=(const DataType &rhs) { + this->m_Data = this->m_Data * rhs; + return (*this); + } + + vec_complex &operator*=(const vec_complex &rhs) { + this->m_Data = this->m_Data * static_cast(rhs); + return (*this); + } + + // Add + vec_complex operator+(const vec_complex &rhs) { + return (vec_complex{m_Data + static_cast(rhs)}); + } + + vec_complex operator+(const DataType &rhs) { + return (vec_complex{m_Data + rhs}); + } + + // Compound Add + vec_complex &operator+=(const DataType &rhs) { + this->m_Data = this->m_Data * rhs; + return (*this); + } + + vec_complex &operator+=(const vec_complex &rhs) { + this->m_Data = this->m_Data + static_cast(rhs); + return (*this); + } + + // Load + template + void load(size_t Offset, + cl::sycl::multi_ptr Ptr) { + m_Data = *(Ptr + Offset * NumElements); + } + + // Store + template + void store(size_t Offset, + cl::sycl::multi_ptr Ptr) const { + *(Ptr + Offset * NumElements) = m_Data; + } +}; + +/*! @brief Partial specialization of the Packetize class dedicated to +sycl::complex types. It contains static methods for loading and storing size=1 +complex packets from/to memory. +* @tparam vector_size The desired vector size to be used. Only size = 1 is +supported so far. +* @tparam value_t The complex type of the matrix data. +*/ +template +struct Packetize, index_t> { + // Vectorization is not enabled for complex, always set to 1 + using value_t = complex_sycl; + using PacketType = vec_complex; + static constexpr int packet_size = 1; + template + static PORTBLAS_INLINE constexpr bool check_size() { + return true; + } + + /*! @brief Performs a non-vectorised load of sycl::complex data element while + * whether block is internal or not since vectorization is not enabled for + * complex types yet. + * @tparam trans Whether the source matrix is transposed or not. + * @tparam internal True if the current block is internal and no bounds + * checking is required. + * @tparam ld The leading dimension of the destination memory. */ + template + static PORTBLAS_INLINE void load(const bool in_range, SrcPointerType src, + DestPointerType dest, + EdgePredicate edge_in_range) { + *(dest) = in_range ? *(src) : value_t{(T)0, (T)0}; + } + + /*! @brief Store a size = 1 vector packet of sycl::complex data into local + * memory (whether source is transposed or not since it's only 1 element). + * @tparam trans Whether the source matrix is transposed or not. + * @tparam ld The leading dimension of the destination memory.*/ + template + static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) { + *dest = packet[0]; + } +}; +#endif +} // namespace blas + +#endif // PORTBLAS_BLAS3_GEMM_LOAD_STORE_CPLX_HPP diff --git a/src/operations/blas3/gemm_local.hpp b/src/operations/blas3/gemm_local.hpp index 9b1c1c98b..0ca182918 100644 --- a/src/operations/blas3/gemm_local.hpp +++ b/src/operations/blas3/gemm_local.hpp @@ -27,6 +27,9 @@ #include "gemm_common.hpp" #include "gemm_load_store.hpp" +#ifdef BLAS_ENABLE_COMPLEX +#include "gemm_load_store_complex.hpp" +#endif namespace blas { @@ -142,6 +145,12 @@ class Gemm::value) || + is_sycl_scalar::value, + "Vector size should be equal to 1 for Complex Data types"); +#endif + //! @brief leading dimension of block of A in local static constexpr index_t ldsa = block_rows + nbc_a; //! @brief leading dimension of block of B in local @@ -162,8 +171,8 @@ class Gemm PORTBLAS_INLINE void eval(local_memory_t scratch_acc, - const cl::sycl::nd_item<1> &id) noexcept { + const cl::sycl::nd_item<1> &id) noexcept { index_t m = a_.get_size_row(); index_t n = b_.get_size_col(); const index_t k = a_.get_size_col(); @@ -546,9 +555,9 @@ class Gemm PORTBLAS_INLINE void store_output_block(index_t, index_t mc, index_t nc, - OutputPointerType C, index_t ldc, - element_t *reg_res, - const bool out_of_range) noexcept { + OutputPointerType C, index_t ldc, + element_t *reg_res, + const bool out_of_range) noexcept { if (out_of_range) { return; } @@ -726,9 +735,9 @@ class Gemm PORTBLAS_INLINE void compute_block_gemm(index_t, InputPointerType B, - InputPointerType A, element_t *reg_a, - element_t ®_b, - element_t *reg_res) noexcept { + InputPointerType A, element_t *reg_a, + element_t ®_b, + element_t *reg_res) noexcept { // NOTE: Adding "#pragma unroll" here reduces performance on AMD R9 // Nano. // Seems that the small reduction of arithmetic operations does @@ -754,7 +763,7 @@ class Gemm(reg_a[l], reg_b, reg_res[j * item_rows + l]); } } A = A + ldsa; @@ -781,7 +790,7 @@ class Gemm static PORTBLAS_INLINE typename std::enable_if::type sync_smem( const cl::sycl::nd_item<1> &id, index_t &ofs_sign, P &s, - Ps &... ss) noexcept { + Ps &...ss) noexcept { s += ofs_sign * o; sync_smem(id, ofs_sign, ss...); } diff --git a/src/operations/blas3/gemm_no_local_full_vec.hpp b/src/operations/blas3/gemm_no_local_full_vec.hpp index a5dc683f3..77cbafbbf 100644 --- a/src/operations/blas3/gemm_no_local_full_vec.hpp +++ b/src/operations/blas3/gemm_no_local_full_vec.hpp @@ -27,6 +27,9 @@ #include "gemm_common.hpp" #include "gemm_load_store.hpp" +#ifdef BLAS_ENABLE_COMPLEX +#include "gemm_load_store_complex.hpp" +#endif namespace blas { @@ -69,6 +72,7 @@ class Gemm::type; using address_t = cl::sycl::access::address_space; using packetize_t = Packetize; + using vector_t = typename packetize_t::PacketType; static constexpr int local_memory_size = 0; /*! @brief The number of rows processed by each work item */ static constexpr index_t item_rows = tile_type::item_rows; @@ -103,6 +107,12 @@ class Gemm(), "If vectorization is enabled item_cols must equal the packet_size"); +#ifdef BLAS_ENABLE_COMPLEX + static_assert((VectorSize == 1 && is_complex_sycl::value) || + is_sycl_scalar::value, + "Vector size should be equal to 1 for Complex Data types"); +#endif + input_t a_; input_t b_; output_t c_; @@ -114,8 +124,8 @@ class Gemm(check_boundary( dim_m_c_start + j * wg_rows, dim_n_c_start + i * wg_cols))) { - cl::sycl::vec out_vec{}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t out_vec{}; out_vec.template load( 0, cl::sycl::multi_ptr( @@ -552,7 +564,9 @@ class Gemm(is_valid_row(j * ptr_next + work_per_load - 1)); - cl::sycl::vec in_vec{}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t in_vec{}; if (in_range) { // if in range perform a vectorised load in_vec.template load( @@ -630,7 +644,9 @@ class Gemm(is_valid_col(work_per_load - 1)); - cl::sycl::vec in_vec{}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t in_vec{}; if (in_range) { // if in range perform a vectorised load in_vec.template load( @@ -705,7 +721,9 @@ class Gemm(is_valid_row(work_per_load - 1)) && do_check(is_valid_col(col_ofs)); - cl::sycl::vec in_vec{}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t in_vec{}; if (in_range) { // If in range perform a vectorised load. in_vec.template load( @@ -768,7 +786,9 @@ class Gemm(is_valid_row(row_ofs)) && do_check(is_valid_col(work_per_load - 1)); - cl::sycl::vec in_vec{}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t in_vec{}; if (in_range) { // If in range perform a vectorised load. in_vec.template load( @@ -808,7 +828,7 @@ class Gemm(reg_a[j], reg_b[i], reg_res[i * item_rows + j]); } } } @@ -860,7 +880,7 @@ class Gemm(reg_a[j], *reg_b, reg_res[j]); } } @@ -887,11 +907,11 @@ class Gemm PORTBLAS_INLINE void store(PointerType C, element_t *reg_res, - const index_t &dim_m_c_start, - const index_t &dim_n_c_start, - const check_boundary &chk_boundary, - const bool out_of_range, - const index_t &ldc) noexcept { + const index_t &dim_m_c_start, + const index_t &dim_n_c_start, + const check_boundary &chk_boundary, + const bool out_of_range, + const index_t &ldc) noexcept { if (out_of_range) { return; } @@ -901,7 +921,9 @@ class Gemm(chk_boundary(dim_m_c_start + j * wg_rows, dim_n_c_start + i * wg_cols))) { - cl::sycl::vec out_vec{}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t out_vec{}; out_vec.template load( 0, cl::sycl::multi_ptr( diff --git a/src/operations/blas3/gemm_no_local_partial_vec.hpp b/src/operations/blas3/gemm_no_local_partial_vec.hpp index eb3d19473..ba26ef67f 100644 --- a/src/operations/blas3/gemm_no_local_partial_vec.hpp +++ b/src/operations/blas3/gemm_no_local_partial_vec.hpp @@ -27,6 +27,9 @@ #include "gemm_common.hpp" #include "gemm_load_store.hpp" +#ifdef BLAS_ENABLE_COMPLEX +#include "gemm_load_store_complex.hpp" +#endif namespace blas { @@ -69,6 +72,7 @@ class Gemm::type; using address_t = cl::sycl::access::address_space; using packetize_t = Packetize; + using vector_t = typename packetize_t::PacketType; static constexpr int local_memory_size = 0; /*! @brief The number of rows processed by each work item */ static constexpr index_t item_rows = tile_type::item_rows; @@ -99,6 +103,12 @@ class Gemm::value) || + is_sycl_scalar::value, + "Vector size should be equal to 1 for Complex Data types"); +#endif + input_t a_; input_t b_; output_t c_; @@ -110,8 +120,8 @@ class Gemm PORTBLAS_INLINE void load(PointerType ptr, element_t *reg, const index_t &ld, - index_t index, const check_boundary &chk_boundary, - const bool out_of_range) noexcept { + index_t index, const check_boundary &chk_boundary, + const bool out_of_range) noexcept { if (out_of_range) { return; } @@ -458,7 +468,9 @@ class Gemm(chk_boundary(index + (work_per_load - 1))); - cl::sycl::vec in_vec{0}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t in_vec{0}; if (in_range) { in_vec.template load( 0, @@ -488,7 +500,7 @@ class Gemm(reg_a[j], reg_b[i], reg_res[i * item_rows + j]); } } } @@ -502,7 +514,9 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type store_packet( element_t *reg, OutputPointerType out_ptr) { - cl::sycl::vec out_vec{0}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t out_vec{0}; out_vec.template load( 0, cl::sycl::multi_ptr(reg)); @@ -531,11 +545,11 @@ class Gemm PORTBLAS_INLINE void store(PointerType C, element_t *reg_res, - const index_t &dim_m_c_start, - const index_t &dim_n_c_start, - const check_boundary &chk_boundary, - const bool out_of_range, - const index_t &ldc) noexcept { + const index_t &dim_m_c_start, + const index_t &dim_n_c_start, + const check_boundary &chk_boundary, + const bool out_of_range, + const index_t &ldc) noexcept { if (out_of_range) { return; } @@ -545,7 +559,9 @@ class Gemm(chk_boundary(dim_m_c_start + j * wg_rows, dim_n_c_start + i * wg_cols))) { - cl::sycl::vec out_vec{0}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t out_vec{0}; out_vec.template load( 0, cl::sycl::multi_ptr( diff --git a/src/operations/blas3/gemm_partial_local.hpp b/src/operations/blas3/gemm_partial_local.hpp index a9de19fb8..a6f8bf30a 100644 --- a/src/operations/blas3/gemm_partial_local.hpp +++ b/src/operations/blas3/gemm_partial_local.hpp @@ -309,8 +309,8 @@ class GemmPartial( + privateLhs, privateRhs, private_res[wLPTM + idx]); lhs_index += tile_type::wg_rows; } diff --git a/test/blas_test.hpp b/test/blas_test.hpp index 1d0f39de3..d159109db 100644 --- a/test/blas_test.hpp +++ b/test/blas_test.hpp @@ -149,6 +149,34 @@ static inline void fill_random(std::vector &vec) { fill_random_with_range(vec, scalar_t{-2}, scalar_t{5}); } +#ifdef BLAS_ENABLE_COMPLEX +/** + * @brief Generates a random vector of std::complex values, using a + * uniform distribution. + * @param vec Input vector to fill + * @param rangeMin Minimum value for the uniform distribution (real & imag) + * @param rangeMax Maximum value for the uniform distribution (real & imag) + */ +template +static inline void fill_random_with_range( + std::vector> &vec, scalar_t rangeMin, + scalar_t rangeMax) { + for (std::complex &e : vec) { + e = std::complex{random_scalar(rangeMin, rangeMax), + random_scalar(rangeMin, rangeMax)}; + } +} + +/** + * @brief Generates a random vector of std::complex values, using a + * uniform distribution. + */ +template +static inline void fill_random(std::vector> &vec) { + fill_random_with_range(vec, scalar_t{-2}, scalar_t{5}); +} +#endif + /** * @brief Fills a lower or upper triangular matrix suitable for TRSM testing * @param A The matrix to fill. Size must be at least m * lda @@ -165,7 +193,7 @@ static inline void fill_random(std::vector &vec) { * @param unused Value to put in the unused parts of the matrix */ template -static inline void fill_trsm_matrix(std::vector& A, size_t k, +static inline void fill_trsm_matrix(std::vector &A, size_t k, size_t lda, char uplo, char unit_diag, scalar_t diag = scalar_t{1}, scalar_t unused = scalar_t{0}) { @@ -262,6 +290,24 @@ struct dump_arg_helper { } }; +#ifdef BLAS_ENABLE_COMPLEX +/** Specialization of dump_arg_helper for std::complex types. + * This is required to split the real & imag parts properly and avoid + * by-default parentheses format. + **/ +template +struct dump_arg_helper< + T, typename std::enable_if::value>::type> { + inline void operator()(std::ostream &ss, T f) { + using scalar_t = typename T::value_type; + dump_arg_helper{}(ss, f.real()); + ss << "r"; + dump_arg_helper{}(ss, f.imag()); + ss << "i"; + } +}; +#endif + /** * Type of the tested api */ diff --git a/test/blas_test_macros.hpp b/test/blas_test_macros.hpp index 5b4cf979c..89e733e60 100644 --- a/test/blas_test_macros.hpp +++ b/test/blas_test_macros.hpp @@ -93,6 +93,36 @@ combination, name_generator) #endif // BLAS_DATA_TYPE_HALF +#ifdef BLAS_ENABLE_COMPLEX +#define BLAS_REGISTER_TEST_CPLX_S_CUSTOM_NAME(test_suite, class_name, \ + test_function, combination_t, \ + combination, name_generator) \ + class class_name##CplxFloat \ + : public ::testing::TestWithParam> {}; \ + TEST_P(class_name##CplxFloat, test) { test_function(GetParam()); }; \ + INSTANTIATE_TEST_SUITE_P(test_suite, class_name##CplxFloat, \ + combination, name_generator); +#else +#define BLAS_REGISTER_TEST_CPLX_S_CUSTOM_NAME(test_suite, class_name, \ + test_function, combination_t, \ + combination, name_generator) +#endif // BLAS_ENABLE_COMPLEX + +#if defined(BLAS_DATA_TYPE_DOUBLE) & defined(BLAS_ENABLE_COMPLEX) +#define BLAS_REGISTER_TEST_CPLX_D_CUSTOM_NAME(test_suite, class_name, \ + test_function, combination_t, \ + combination, name_generator) \ + class class_name##CplxDouble \ + : public ::testing::TestWithParam> {}; \ + TEST_P(class_name##CplxDouble, test) { test_function(GetParam()); }; \ + INSTANTIATE_TEST_SUITE_P(test_suite, class_name##CplxDouble, \ + combination, name_generator); +#else +#define BLAS_REGISTER_TEST_CPLX_D_CUSTOM_NAME(test_suite, class_name, \ + test_function, combination_t, \ + combination, name_generator) +#endif // BLAS_ENABLE_COMPLEX & BLAS_ENABLE_COMPLEX + /** Registers test for all supported data types * @param test_suite Name of the test suite * @param class_name Base name of the test class @@ -115,6 +145,18 @@ combination_t, combination, \ name_generator); +#ifdef BLAS_ENABLE_COMPLEX +#define BLAS_REGISTER_CPLX_TEST_CUSTOM_NAME(test_suite, class_name, \ + test_function, combination_t, \ + combination, name_generator) \ + BLAS_REGISTER_TEST_CPLX_S_CUSTOM_NAME(test_suite, class_name, test_function, \ + combination_t, combination, \ + name_generator); \ + BLAS_REGISTER_TEST_CPLX_D_CUSTOM_NAME(test_suite, class_name, test_function, \ + combination_t, combination, \ + name_generator); +#endif // BLAS_ENABLE_COMPLEX + /** Registers test for all supported data types * @see BLAS_REGISTER_TEST_CUSTOM_NAME */ diff --git a/test/unittest/CMakeLists.txt b/test/unittest/CMakeLists.txt index 4f824238d..b4d2b0a3b 100644 --- a/test/unittest/CMakeLists.txt +++ b/test/unittest/CMakeLists.txt @@ -116,6 +116,11 @@ foreach(blas_test ${SYCL_UNITTEST_SRCS}) if(STRESS_TESTING) target_compile_definitions(${test_exec} PRIVATE STRESS_TESTING) endif() + if(${BLAS_ENABLE_COMPLEX}) + if(${test_exec} MATCHES "gemm") + target_compile_definitions(${test_exec} PRIVATE BLAS_ENABLE_COMPLEX=1) + endif() + endif() target_compile_definitions(${test_exec} PRIVATE -DBLAS_INDEX_T=${BLAS_TEST_INDEX_TYPE}) target_link_libraries(${test_exec} PRIVATE gtest_main Clara::Clara blas::blas portblas) target_include_directories(${test_exec} PRIVATE ${CBLAS_INCLUDE} ${PORTBLAS_COMMON_INCLUDE_DIR}) diff --git a/test/unittest/blas3/blas3_gemm_batched_test.cpp b/test/unittest/blas3/blas3_gemm_batched_test.cpp index 1ce9413bd..824bf656b 100644 --- a/test/unittest/blas3/blas3_gemm_batched_test.cpp +++ b/test/unittest/blas3/blas3_gemm_batched_test.cpp @@ -145,3 +145,67 @@ const auto AllStridedBatched = ::testing::Values(1, 2, 3) // stride_c_mul ); GENERATE_GEMM_STRIDED_BATCHED_TEST(BatchStridedGemm, AllStridedBatched); + +#ifdef BLAS_ENABLE_COMPLEX +template +const auto CplxBetaNonZeroLDMatch = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(3), // batch + ::testing::Values(63, 128), // m + ::testing::Values(63, 128), // n + ::testing::Values(63, 128), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 1.0}), // alpha + ::testing::Values>({1.5, 3.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(BatchGemm, CplxBetaNonZeroLDMatch); + +template +const auto CplxDefaultGemmAndGemmBatched = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1, 4), // batch + ::testing::Values(63, 128), // m + ::testing::Values(63, 128), // n + ::testing::Values(63, 128), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({2.5, 1.0}), // alpha + ::testing::Values>({1.5, 3.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(1), // stride_a_mul + ::testing::Values(1), // stride_b_mul + ::testing::Values(1) // stride_c_mul +); +GENERATE_CPLXGEMM_STRIDED_BATCHED_TEST(BatchStridedGemm, + CplxDefaultGemmAndGemmBatched); + +template +const auto CplxAllStridedBatched = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(3), // batch + ::testing::Values(128), // m + ::testing::Values(128), // n + ::testing::Values(128), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({2.5, 1.0}), // alpha + ::testing::Values>({1.5, 3.0}), // beta + ::testing::Values(2), // lda_mul + ::testing::Values(3), // ldb_mul + ::testing::Values(4), // ldc_mul + ::testing::Values(0, 1, 2), // stride_a_mul + ::testing::Values(0, 1, 2), // stride_b_mul + ::testing::Values(1, 2, 3) // stride_c_mul +); +GENERATE_CPLXGEMM_STRIDED_BATCHED_TEST(BatchStridedGemm, CplxAllStridedBatched); +#endif diff --git a/test/unittest/blas3/blas3_gemm_common.hpp b/test/unittest/blas3/blas3_gemm_common.hpp index 48bd28128..b9bd04e04 100644 --- a/test/unittest/blas3/blas3_gemm_common.hpp +++ b/test/unittest/blas3/blas3_gemm_common.hpp @@ -37,6 +37,19 @@ using gemm_batched_strided_arguments_t = std::tuple; +#ifdef BLAS_ENABLE_COMPLEX +template +using gemm_cplx_arguments_t = + std::tuple, std::complex, int, int, int, + gemm_batch_type_t>; + +template +using gemm_cplx_batched_strided_arguments_t = + std::tuple, std::complex, int, int, int, int, int, int>; +#endif + // Convert batch_type=strided to interleaved on the host template inline std::vector strided_to_interleaved( @@ -383,4 +396,327 @@ static std::string generate_batched_strided_name( BLAS_REGISTER_TEST_CUSTOM_NAME(test_suite, test_suite##combination, \ verify_gemm, \ gemm_batched_strided_arguments_t, \ - combination, generate_batched_strided_name); \ No newline at end of file + combination, generate_batched_strided_name); + +#ifdef BLAS_ENABLE_COMPLEX + +template +inline void verify_gemm(const gemm_cplx_arguments_t arguments) { + std::string alloc; + index_t offset; + index_t batch; + index_t m; + index_t n; + index_t k; + char transa; + char transb; + std::complex alpha; + std::complex beta; + index_t lda_mul; + index_t ldb_mul; + index_t ldc_mul; + gemm_batch_type_t batch_type; + std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul, + ldb_mul, ldc_mul, batch_type) = arguments; + + if (batch > 1 && batch_type == gemm_batch_type_t::interleaved) { + // Interleaved batched gemm unsupported with complex data types + GTEST_SKIP(); + } + + const char ta_str[2] = {transa, '\0'}; + const char tb_str[2] = {transb, '\0'}; + + auto q = make_queue(); + blas::SB_Handle sb_handle(q); + + const index_t lda = ((transa != 'n') ? k : m) * lda_mul; + const index_t ldb = ((transb != 'n') ? n : k) * ldb_mul; + const index_t ldc = m * ldc_mul; + + const index_t size_a = m * k * lda_mul; + const index_t size_b = k * n * ldb_mul; + const index_t size_c = m * n * ldc_mul; + + const index_t buffer_size_a = batch * size_a + offset; + const index_t buffer_size_b = batch * size_b + offset; + const index_t buffer_size_c = batch * size_c + offset; + + std::vector> a_m(buffer_size_a); + std::vector> b_m(buffer_size_b); + std::vector> c_m_gpu(buffer_size_c); + + fill_random(a_m); + fill_random(b_m); + fill_random(c_m_gpu); + std::vector> c_m_cpu = c_m_gpu; + + // Use system blas to create a reference output + for (int i = 0; i < batch; ++i) { + reference_blas::cgemm( + ta_str, tb_str, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a_m.data() + i * size_a + offset), lda, + reinterpret_cast(b_m.data() + i * size_b + offset), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_m_cpu.data() + i * size_c + offset), ldc); + } + + auto m_a_gpu = blas::helper::allocate>( + buffer_size_a, q); + auto m_b_gpu = blas::helper::allocate>( + buffer_size_b, q); + auto m_c_gpu = blas::helper::allocate>( + buffer_size_c, q); + + auto copy_a = blas::helper::copy_to_device( + q, reinterpret_cast*>(a_m.data()), m_a_gpu, + buffer_size_a); + auto copy_b = blas::helper::copy_to_device( + q, reinterpret_cast*>(b_m.data()), m_b_gpu, + buffer_size_b); + auto copy_c = blas::helper::copy_to_device( + q, reinterpret_cast*>(c_m_gpu.data()), m_c_gpu, + buffer_size_c); + + complex_sycl alpha_sycl(alpha); + complex_sycl beta_sycl(beta); + + // portBLAS GEMM implementation + typename blas::SB_Handle::event_t gemm_event; + if (batch == index_t(1)) { + gemm_event = _gemm(sb_handle, transa, transb, m, n, k, alpha_sycl, + m_a_gpu + offset, lda, m_b_gpu + offset, ldb, beta_sycl, + m_c_gpu + offset, ldc, {copy_a, copy_b, copy_c}); + } else { + return; + _gemm_batched(sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, + lda, m_b_gpu + offset, ldb, beta, m_c_gpu + offset, ldc, + batch, batch_type, {copy_a, copy_b, copy_c}); + } + sb_handle.wait(gemm_event); + + auto event = blas::helper::copy_to_host( + q, m_c_gpu, reinterpret_cast*>(c_m_gpu.data()), + buffer_size_c); + sb_handle.wait(event); + + const bool isAlmostEqual = utils::compare_vectors(c_m_gpu, c_m_cpu); + ASSERT_TRUE(isAlmostEqual); + + helper::deallocate(m_a_gpu, q); + helper::deallocate(m_b_gpu, q); + helper::deallocate(m_c_gpu, q); +} + +template +inline void verify_gemm(const gemm_cplx_arguments_t arguments) { + std::string alloc; + index_t offset; + index_t batch; + index_t m; + index_t n; + index_t k; + char transa; + char transb; + std::complex alpha; + std::complex beta; + index_t lda_mul; + index_t ldb_mul; + index_t ldc_mul; + gemm_batch_type_t batch_type; + std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul, + ldb_mul, ldc_mul, batch_type) = arguments; + + if (alloc == "usm") { +#ifdef SB_ENABLE_USM + verify_gemm(arguments); +#else + GTEST_SKIP(); +#endif + } else { + verify_gemm(arguments); + } +} + +template +static std::string generate_cplx_name( + const ::testing::TestParamInfo>& info) { + std::string alloc; + int offset, batch, m, n, k, ldaMul, ldbMul, ldcMul; + char transa, transb; + std::complex alpha, beta; + gemm_batch_type_t batchType; + BLAS_GENERATE_NAME(info.param, alloc, offset, batch, m, n, k, transa, transb, + alpha, beta, ldaMul, ldbMul, ldcMul, batchType); +} + +template +inline void verify_gemm( + const gemm_cplx_batched_strided_arguments_t arguments) { + std::string alloc; + index_t offset; + index_t batch; + index_t m; + index_t n; + index_t k; + char transa; + char transb; + std::complex alpha; + std::complex beta; + index_t lda_mul; + index_t ldb_mul; + index_t ldc_mul; + index_t stride_a_mul; + index_t stride_b_mul; + index_t stride_c_mul; + std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul, + ldb_mul, ldc_mul, stride_a_mul, stride_b_mul, stride_c_mul) = + arguments; + + const char ta_str[2] = {transa, '\0'}; + const char tb_str[2] = {transb, '\0'}; + + auto q = make_queue(); + blas::SB_Handle sb_handle(q); + + const index_t lda = ((transa != 'n') ? k : m) * lda_mul; + const index_t ldb = ((transb != 'n') ? n : k) * ldb_mul; + const index_t ldc = m * ldc_mul; + + const index_t size_a = m * k * lda_mul; + const index_t size_b = k * n * ldb_mul; + const index_t size_c = m * n * ldc_mul; + + const index_t stride_a = stride_a_mul * size_a; + const index_t stride_b = stride_b_mul * size_b; + const index_t stride_c = stride_c_mul * size_c; + + const index_t buffer_size_a = size_a + (batch - 1) * stride_a + offset; + const index_t buffer_size_b = size_b + (batch - 1) * stride_b + offset; + const index_t buffer_size_c = size_c + (batch - 1) * stride_c + offset; + + std::vector> a_m(buffer_size_a); + std::vector> b_m(buffer_size_b); + std::vector> c_m_gpu(buffer_size_c); + + fill_random(a_m); + fill_random(b_m); + fill_random(c_m_gpu); + std::vector> c_m_cpu = c_m_gpu; + + // Use system blas to create a reference output + for (int i = 0; i < batch; ++i) { + reference_blas::cgemm( + ta_str, tb_str, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a_m.data() + i * stride_a + offset), lda, + reinterpret_cast(b_m.data() + i * stride_b + offset), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_m_cpu.data() + i * stride_c + offset), ldc); + } + + auto m_a_gpu = blas::helper::allocate>( + buffer_size_a, q); + auto m_b_gpu = blas::helper::allocate>( + buffer_size_b, q); + auto m_c_gpu = blas::helper::allocate>( + buffer_size_c, q); + + auto copy_a = blas::helper::copy_to_device( + q, reinterpret_cast*>(a_m.data()), m_a_gpu, + buffer_size_a); + auto copy_b = blas::helper::copy_to_device( + q, reinterpret_cast*>(b_m.data()), m_b_gpu, + buffer_size_b); + auto copy_c = blas::helper::copy_to_device( + q, reinterpret_cast*>(c_m_gpu.data()), m_c_gpu, + buffer_size_c); + + complex_sycl alpha_sycl(alpha); + complex_sycl beta_sycl(beta); + + // portBLAS GEMM STRIDED BATCHED implementation + auto gemm_batched_event = _gemm_strided_batched( + sb_handle, transa, transb, m, n, k, alpha_sycl, m_a_gpu + offset, lda, + stride_a, m_b_gpu + offset, ldb, stride_b, beta_sycl, m_c_gpu + offset, + ldc, stride_c, batch, {copy_a, copy_b, copy_c}); + + sb_handle.wait({gemm_batched_event}); + auto event = blas::helper::copy_to_host( + q, m_c_gpu, reinterpret_cast*>(c_m_gpu.data()), + buffer_size_c); + sb_handle.wait(event); + + const bool isAlmostEqual = + (stride_c_mul == 1) + ? utils::compare_vectors(c_m_gpu, c_m_cpu) + : utils::compare_vectors_strided(c_m_gpu, c_m_cpu, stride_c, size_c); + ASSERT_TRUE(isAlmostEqual); + + helper::deallocate(m_a_gpu, q); + helper::deallocate(m_b_gpu, q); + helper::deallocate(m_c_gpu, q); +} + +template +inline void verify_gemm( + const gemm_cplx_batched_strided_arguments_t arguments) { + std::string alloc; + index_t offset; + index_t batch; + index_t m; + index_t n; + index_t k; + char transa; + char transb; + std::complex alpha; + std::complex beta; + index_t lda_mul; + index_t ldb_mul; + index_t ldc_mul; + index_t stride_a_mul; + index_t stride_b_mul; + index_t stride_c_mul; + std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul, + ldb_mul, ldc_mul, stride_a_mul, stride_b_mul, stride_c_mul) = + arguments; + + if (alloc == "usm") { +#ifdef SB_ENABLE_USM + verify_gemm(arguments); +#endif + } else { + verify_gemm(arguments); + } +} + +template +static std::string generate_cplx_batched_strided_name( + const ::testing::TestParamInfo>& + info) { + std::string alloc; + int offset, batch, m, n, k, ldaMul, ldbMul, ldcMul, stride_a_mul, + stride_b_mul, stride_c_mul; + char transa, transb; + std::complex alpha, beta; + BLAS_GENERATE_NAME(info.param, alloc, offset, batch, m, n, k, transa, transb, + alpha, beta, ldaMul, ldbMul, ldcMul, stride_a_mul, + stride_b_mul, stride_c_mul); +} + +/** Registers GEMM test for all supported complex data types + * @param test_suite Name of the test suite + * @param combination Combinations object + * @see BLAS_REGISTER_TEST_CUSTOM_NAME + */ +#define GENERATE_CPLX_GEMM_TEST(test_suite, combination) \ + BLAS_REGISTER_CPLX_TEST_CUSTOM_NAME(test_suite, test_suite##combination, \ + verify_gemm, gemm_cplx_arguments_t, \ + combination, generate_cplx_name); + +#define GENERATE_CPLXGEMM_STRIDED_BATCHED_TEST(test_suite, combination) \ + BLAS_REGISTER_CPLX_TEST_CUSTOM_NAME( \ + test_suite, test_suite##combination, verify_gemm, \ + gemm_cplx_batched_strided_arguments_t, combination, \ + generate_cplx_batched_strided_name); + +#endif diff --git a/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp b/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp index 5e156b7c5..95abb271a 100644 --- a/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp +++ b/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp @@ -101,3 +101,81 @@ const auto OffsetNonZero = ::testing::Combine( ::testing::Values(gemm_batch_type_t::strided) // batch_type ); GENERATE_GEMM_TEST(TallSkinnyGemm, OffsetNonZero); + +#ifdef BLAS_ENABLE_COMPLEX +template +const auto CplxBetaNonZeroLDMatch = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(7, 65), // m + ::testing::Values(9, 126), // n + ::testing::Values(2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 1.5}), // alpha + ::testing::Values>({0.5, 0.5}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(TallSkinnyGemm, CplxBetaNonZeroLDMatch); + +template +const auto CplxBetaNonZeroLDMultiplied = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(7, 33), // m + ::testing::Values(9, 63), // n + ::testing::Values(1026), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 0.5}), // alpha + ::testing::Values>({0.5, 1.5}), // beta + ::testing::Values(2), // lda_mul + ::testing::Values(3), // ldb_mul + ::testing::Values(4), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(TallSkinnyGemm, CplxBetaNonZeroLDMultiplied); + +template +const auto CplxBetaZero = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(7), // m + ::testing::Values(9), // n + ::testing::Values(1026), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 2.0}), // alpha + ::testing::Values>({0.0, 0.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(TallSkinnyGemm, CplxBetaZero); + +template +const auto CplxOffsetNonZero = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(10), // offset + ::testing::Values(1), // batch + ::testing::Values(7), // m + ::testing::Values(9), // n + ::testing::Values(1026), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 2.5}), // alpha + ::testing::Values>({0.5, 1.5}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(TallSkinnyGemm, CplxOffsetNonZero); +#endif diff --git a/test/unittest/blas3/blas3_gemm_test.cpp b/test/unittest/blas3/blas3_gemm_test.cpp index e5d4a4122..f7cae4630 100644 --- a/test/unittest/blas3/blas3_gemm_test.cpp +++ b/test/unittest/blas3/blas3_gemm_test.cpp @@ -139,3 +139,121 @@ const auto LargeBetaNonZeroLDMatch = ::testing::Combine( ::testing::Values(gemm_batch_type_t::strided) // batch_type ); GENERATE_GEMM_TEST(Gemm, LargeBetaNonZeroLDMatch); + +#ifdef BLAS_ENABLE_COMPLEX +template +const auto CplxSmallBetaNonZeroLDMatch = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 33), // m + ::testing::Values(11, 33), // n + ::testing::Values(16, 17), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 1.0}), // alpha + ::testing::Values>({1.5, 3.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxSmallBetaNonZeroLDMatch); + +template +const auto CplxSmallBetaZeroLDMatch = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 32), // m + ::testing::Values(11, 32), // n + ::testing::Values(17), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 1.0}), // alpha + ::testing::Values>({0.0, 0.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxSmallBetaZeroLDMatch); + +template +const auto CplxSmallBetaZeroLDMultiplied = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 33), // m + ::testing::Values(11, 33), // n + ::testing::Values(17), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 3.0}), // alpha + ::testing::Values>({0.0, 0.0}), // beta + ::testing::Values(2), // lda_mul + ::testing::Values(2), // ldb_mul + ::testing::Values(3), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxSmallBetaZeroLDMultiplied); + +template +const auto CplxAlphaZero = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 10), // offset + ::testing::Values(1), // batch + ::testing::Values(16), // m + ::testing::Values(16), // n + ::testing::Values(17), // k + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values>({0.0, 0.0}), // alpha + ::testing::Values(std::complex{0.0, 0.0}, + std::complex{1.0, 0.0}), // beta + ::testing::Values(1, 2), // lda_mul + ::testing::Values(1, 2), // ldb_mul + ::testing::Values(1, 2), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxAlphaZero); + +template +const auto CplxOffsetNonZero = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(1, 10), // offset + ::testing::Values(1), // batch + ::testing::Values(16, 63), // m + ::testing::Values(16, 63), // n + ::testing::Values(17, 63), // k + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values>({1.0, 1.0}), // alpha + ::testing::Values>({1.0, 1.0}), // beta + ::testing::Values(1, 2), // lda_mul + ::testing::Values(1, 2), // ldb_mul + ::testing::Values(1, 2), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxOffsetNonZero); + +template +const auto CplxLargeBetaNonZeroLDMatch = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(63, 253), // m + ::testing::Values(63, 253), // n + ::testing::Values(63, 253), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.0, 1.5}), // alpha + ::testing::Values>({1.5, 1.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxLargeBetaNonZeroLDMatch); + +#endif