From 40f95fb3082d241c1e05e012e7df8a9de91ea93c Mon Sep 17 00:00:00 2001 From: nscipione Date: Mon, 26 Jun 2023 14:20:22 +0100 Subject: [PATCH] [benchmark] Refactor state counters --- benchmark/syclblas/extension/omatcopy.cpp | 12 +-- .../common/blas_extension_state_counters.hpp | 79 +++++++++++++++++++ common/include/common/common_utils.hpp | 1 + 3 files changed, 83 insertions(+), 9 deletions(-) create mode 100644 common/include/common/blas_extension_state_counters.hpp diff --git a/benchmark/syclblas/extension/omatcopy.cpp b/benchmark/syclblas/extension/omatcopy.cpp index 4394ceabb..9f301a121 100644 --- a/benchmark/syclblas/extension/omatcopy.cpp +++ b/benchmark/syclblas/extension/omatcopy.cpp @@ -51,15 +51,9 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, int ti, const auto size_a = lda * n; const auto size_b = ldb * ((*t_str == 't') ? m : n); - blas_benchmark::utils::init_level_1_counters< - blas_benchmark::utils::Level1Op::copy, scalar_t>(state, 2 * m * n); - - state.counters["n_fl_ops"] = static_cast(m * n); - state.counters["lda_m"] = (double)lda_mul; - state.counters["ldb_m"] = (double)ldb_mul; - state.counters["trans"] = (double)((*t_str == 't') ? 1 : 0); - state.counters["m"] = (double)m; - state.counters["n"] = (double)n; + blas_benchmark::utils::init_extension_counters< + blas_benchmark::utils::ExtensionOP::omatcopy, scalar_t>( + state, t_str, m, n, lda_mul, ldb_mul); blas::SB_Handle& sb_handle = *sb_handle_ptr; diff --git a/common/include/common/blas_extension_state_counters.hpp b/common/include/common/blas_extension_state_counters.hpp new file mode 100644 index 000000000..dc411e033 --- /dev/null +++ b/common/include/common/blas_extension_state_counters.hpp @@ -0,0 +1,79 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SYCL-BLAS: BLAS implementation using SYCL + * + * @filename blas_extension_state_counters.hpp + * + **************************************************************************/ + +#ifndef COMMON_BLAS_EXTENSION_STATE_COUNTERS +#define COMMON_BLAS_EXTENSION_STATE_COUNTERS + +namespace blas_benchmark { +namespace utils { + +enum class ExtensionOP : int { + omatcopy = 0, + imatcopy = 1, + omatadd = 2, + omatcopy_batch = 3, + imatcopy_batch = 4, + omatadd_batch = 5 +}; + +template +inline typename std::enable_if::type +init_extension_counters(benchmark::State& state, const char* trans, index_t m, + index_t n, index_t lda_mul, index_t ldb_mul) { + // Google-benchmark counters are double. + double size_d = static_cast(m * n); + state.counters["m"] = static_cast(m); + state.counters["n"] = static_cast(n); + state.counters["n_fl_ops"] = size_d; + state.counters["lda_m"] = static_cast(lda_mul); + state.counters["ldb_m"] = static_cast(ldb_mul); + state.counters["trans"] = static_cast((*trans == 't') ? 1 : 0); + state.counters["bytes_processed"] = (2 * size_d + 1) * sizeof(scalar_t); + return; +} + +template +inline typename std::enable_if::type +init_extension_counters(benchmark::State& state, const char* t_a, + const char* t_b, index_t m, index_t n, index_t lda_mul, + index_t ldb_mul, index_t ldc_mul) { + // Google-benchmark counters are double. + double size_d = static_cast(m * n); + state.counters["m"] = static_cast(m); + state.counters["n"] = static_cast(n); + state.counters["n_fl_ops"] = 3 * static_cast(m * n); + state.counters["lda_m"] = static_cast(lda_mul); + state.counters["ldb_m"] = static_cast(ldb_mul); + state.counters["ldc_m"] = static_cast(ldc_mul); + state.counters["trans_a"] = static_cast((*t_a == 't') ? 1 : 0); + state.counters["trans_b"] = static_cast((*t_b == 't') ? 1 : 0); + state.counters["bytes_processed"] = (3 * size_d + 1) * sizeof(scalar_t); + return; +} +} // namespace utils +} // namespace blas_benchmark + +#endif // COMMON_BLAS_EXTENSION_STATE_COUNTERS diff --git a/common/include/common/common_utils.hpp b/common/include/common/common_utils.hpp index 6bc4b03dd..66f9af37f 100644 --- a/common/include/common/common_utils.hpp +++ b/common/include/common/common_utils.hpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include