diff --git a/include/interface/extension_interface.h b/include/interface/extension_interface.h index 92fa4a41d..f2202d24b 100644 --- a/include/interface/extension_interface.h +++ b/include/interface/extension_interface.h @@ -158,16 +158,37 @@ typename sb_handle_t::event_t _omatcopy2(sb_handle_t& sb_handle, char trans, ld_in, inc_in, out_memory, ld_out, inc_out); } +/** + * \brief Computes scaled addition of two matrices A & B with or without + * transpose and copies the result back to an output matrix C. + * + * @tparam sb_handle_t SB_Handle type + * @tparam element_t Underlying element data type of the matrix container + * @tparam index_t Index type + * @tparam container_t Inputs/Output Container Type + * @param trans_a Whether or not to apply matrix transpose to A. + * @param trans_b Whether or not to apply matrix transpose to B. + * @param m Number of rows in output matrix C + * @param n Number of columns in output matrix C + * @param alpha Scaling factor of matrix A + * @param A Container Input matrix A + * @param lda Matrix A leading dimension + * @param beta Scaling factor of matrix B + * @param B Container Input matrix B + * @param ldb Matrix B leading dimension + * @param C Container Output matrix C + * @param ldc Matrix C leading dimension + */ template typename sb_handle_t::event_t _omatadd(sb_handle_t& sb_handle, char trans_a, char trans_b, index_t m, index_t n, - element_t alpha, container_t a, + element_t alpha, container_t A, index_t lda, element_t beta, - container_t b, index_t ldb, - container_t c, index_t ldc) { - return internal::_omatadd(sb_handle, trans_a, trans_b, m, n, alpha, a, lda, - beta, b, ldb, c, ldc); + container_t B, index_t ldb, + container_t C, index_t ldc) { + return internal::_omatadd(sb_handle, trans_a, trans_b, m, n, alpha, A, lda, + beta, B, ldb, C, ldc); } /** diff --git a/test/unittest/extension/omatadd_test.cpp b/test/unittest/extension/omatadd_test.cpp index 14678a9b3..6e7d05a19 100644 --- a/test/unittest/extension/omatadd_test.cpp +++ 
b/test/unittest/extension/omatadd_test.cpp @@ -28,12 +28,29 @@ using index_t = int; namespace reference_blas { -// blas-like extension omatAdd used as wrapper around omatcopy +/** + * @brief Reference omat-add implementation using reference omatcopy. + * + * @param trans_a (char) 'n' or 't' corresponding to non-transposed or + * transposed matrix A respectively. + * @param trans_b (char) 'n' or 't' corresponding to non-transposed or + * transposed matrix B respectively. + * @param m Number of rows in output matrix C + * @param n Number of columns in output matrix C + * @param alpha Scaling factor of matrix A + * @param A (vector) Input matrix A + * @param lda_m Matrix A leading dimension multiplier. (lda = lda_m * A_rows) + * @param beta Scaling factor of matrix B + * @param B (vector) Input matrix B + * @param ldb_m Matrix B leading dimension multiplier. (ldb = ldb_m * B_rows) + * @param C (vector) Output matrix C + * @param ldc_m Matrix C leading dimension multiplier. (ldc = ldc_m * C_rows) + */ template void omatadd(const char trans_a, const char trans_b, const index_t m, - const index_t n, const scalar_t alpha, std::vector &a, - const index_t lda_m, const scalar_t beta, std::vector &b, - const index_t ldb_m, std::vector &c, + const index_t n, const scalar_t alpha, std::vector &A, + const index_t lda_m, const scalar_t beta, std::vector &B, + const index_t ldb_m, std::vector &C, const index_t ldc_m) { const index_t a_rows = trans_a == 't' ? n : m; const index_t a_cols = trans_a == 't' ? 
m : n; @@ -42,19 +59,19 @@ void omatadd(const char trans_a, const char trans_b, const index_t m, index_t ldc = ldc_m * m; - // Temp Matrix 1 for computing a -> alpha * op(A) + // Temp Matrix 1 for computing A -> alpha * op(A) std::vector TempMatrix1(ldc * n, 0); - omatcopy(trans_a, a_rows, a_cols, alpha, a.data(), lda_m * a_rows, + omatcopy(trans_a, a_rows, a_cols, alpha, A.data(), lda_m * a_rows, TempMatrix1.data(), ldc); - // Temp Matrix 2 for computing b -> beta * op(B) + // Temp Matrix 2 for computing B -> beta * op(B) std::vector TempMatrix2(ldc * n, 0); - omatcopy(trans_b, b_rows, b_cols, beta, b.data(), ldb_m * b_rows, + omatcopy(trans_b, b_rows, b_cols, beta, B.data(), ldb_m * b_rows, TempMatrix2.data(), ldc); - // Compute Sum of Temp matrices -> c + // Compute Sum of Temp matrices -> C for (index_t j = 0; j < n; j++) { for (index_t i = 0; i < m; i++) { - c.at(i + j * ldc) = + C.at(i + j * ldc) = TempMatrix1.at(i + j * ldc) + TempMatrix2.at(i + j * ldc); } } @@ -62,7 +79,6 @@ void omatadd(const char trans_a, const char trans_b, const index_t m, } // namespace reference_blas -// Parameters : trans_a, trans_b, m, n, alpha, beta, lda_m, ldb_m, ldc_m template using combination_t = std::tuple; @@ -118,15 +134,16 @@ void run_test(const combination_t combi) { } template -const auto combi = ::testing::Combine(::testing::Values('n', 't'), - ::testing::Values('n', 't'), - ::testing::Values(16, 33, 63), - ::testing::Values(16, 33, 63), - ::testing::Values(0, 1, 2), - ::testing::Values(0, 1, 2), - ::testing::Values(1, 2), - ::testing::Values(1, 2), - ::testing::Values(1, 2, 3)); +const auto combi = + ::testing::Combine(::testing::Values('n', 't'), // trans_a + ::testing::Values('n', 't'), // trans_b + ::testing::Values(64, 129, 255), // m + ::testing::Values(64, 129, 255), // n + ::testing::Values(0, 1, 2), // alpha + ::testing::Values(0, 1, 2), // beta + ::testing::Values(1, 2), // lda_mul + ::testing::Values(1, 2), // ldb_mul + ::testing::Values(1, 2, 3)); // ldc_mul 
template static std::string generate_name(