Skip to content

Commit

Permalink
Add benchmark counters (#128)
Browse files Browse the repository at this point in the history
* Refactor benchmark configurations

* Rebase

* Unify benchmarks into a single execution.

* Add counters and fix warnings and precision issues

* Merge

---------

Co-authored-by: atharva.dubey <[email protected]>
  • Loading branch information
aacostadiaz and AD2605 authored Oct 8, 2024
1 parent e228757 commit d4f99c9
Showing 1 changed file with 43 additions and 19 deletions.
62 changes: 43 additions & 19 deletions benchmarks/benchmark_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,46 +283,70 @@ struct BenchmarkRunnerGemm {
state.SkipWithError("Disposition Failed.");
}

auto tflop = ((2.0 * options.m * options.n * options.k * options.l) * 1e-12);
auto giga_bytes_transferred = (((options.m * options.k) * sizeof(ElementA) +
(options.k * options.n) * sizeof(ElementB) +
(options.beta != 0 ? 2 : 1) * (options.m * options.n) * sizeof(ElementC)) * 1e-9) *
options.l;
state.counters["m"] = options.m;
state.counters["n"] = options.n;
state.counters["k"] = options.k;
state.counters["l"] = options.l;
state.counters["alpha"] = options.alpha;
state.counters["beta"] = options.beta;

std::stringstream extra_label;
if constexpr (cute::size<0>(StrideA{}) == 1) {
extra_label << "layoutA=ColumnMajor ";
} else if constexpr (cute::size<1>(StrideA{}) == 1) {
extra_label << "layoutA=RowMajor ";
}
if constexpr (cute::size<0>(StrideB{}) == 1) {
extra_label << "layoutB=RowMajor ";
} else if constexpr (cute::size<1>(StrideB{}) == 1) {
extra_label << "layoutB=ColumnMajor ";
}
if constexpr (cute::size<0>(StrideC{}) == 1) {
extra_label << "layoutC=ColumnMajor ";
} else if constexpr (cute::size<1>(StrideC{}) == 1) {
extra_label << "layoutC=RowMajor ";
}
state.SetLabel(extra_label.str());

auto gflop = 2.0 * options.m * options.n * options.k * options.l * 1e-9;
auto mega_bytes_transferred = static_cast<double>(
options.m * options.k * sizeof(ElementA) +
options.k * options.n * sizeof(ElementB) +
(options.beta != 0 ? 2 : 1) * options.m * options.n * sizeof(ElementC)
) * 1e-6 * options.l;

initialize_counters(state);
for(auto _ : state) {
GPU_Clock timer;
timer.start();
gemm_op.run();
auto ms_elapsed = timer.milliseconds();
update_counters(state, ms_elapsed, tflop, giga_bytes_transferred);
update_counters(state, ms_elapsed);
state.SetIterationTime(ms_elapsed / 1000);
}
finalize_counters(state, tflop, giga_bytes_transferred);
finalize_counters(state, gflop, mega_bytes_transferred);
}

private:
// Seeds every user counter that the benchmark loop accumulates into or that
// finalize_counters() later reports, so each one starts from a known value.
//
// state: benchmark state whose `counters` map is populated.
static void initialize_counters(::benchmark::State& state) {
  // Running sum of per-iteration runtimes; update_counters() adds to it with `+=`.
  state.counters["total_runtime_ms"] = 0;
  state.counters["avg_runtime_ms"] = 0;
  state.counters["avg_tflops"] = 0;
  state.counters["avg_throughput"] = 0;
  // Start at the largest double so the first measured runtime always wins the std::min.
  state.counters["best_runtime_ms"] = std::numeric_limits<double>::max();
}

// Records the runtime of one benchmark iteration.
//
// state:      benchmark state holding the user counters.
// ms_elapsed: runtime of the iteration that just finished, in milliseconds.
//
// Only the raw runtime is tracked here (running total and minimum); the derived
// rate counters are computed once, after the loop, by finalize_counters().
static void update_counters(::benchmark::State& state, double ms_elapsed) {
  // Pause the benchmark's own timer around the counter bookkeeping.
  state.PauseTiming();
  state.counters["total_runtime_ms"] += ms_elapsed;
  state.counters["best_runtime_ms"] = std::min<double>(state.counters["best_runtime_ms"], ms_elapsed);
  state.ResumeTiming();
}

// Derives the reported summary counters once the benchmark loop has finished.
//
// state:                  benchmark state whose "total_runtime_ms" and "best_runtime_ms"
//                         counters were filled by update_counters().
// gflop:                  work done by one run, in units of 1e9 floating-point operations.
// mega_bytes_transferred: data moved by one run, in units of 1e6 bytes.
//
// Runtimes are in milliseconds, so gflop / ms yields TFLOP/s and MB / ms yields GB/s.
static void finalize_counters(::benchmark::State& state, double gflop, double mega_bytes_transferred) {
  const auto iterations = state.iterations();
  if (iterations == 0) {
    // No timed iteration ran (for example the benchmark was skipped). Dividing would
    // produce NaN averages and "best" figures derived from the numeric_limits::max()
    // sentinel, so leave the counters at their initial values instead.
    return;
  }

  const double avg_runtime_ms = state.counters["total_runtime_ms"] / static_cast<double>(iterations);
  const double best_runtime_ms = state.counters["best_runtime_ms"];

  state.counters["avg_runtime_ms"] = avg_runtime_ms;
  state.counters["avg_tflops"] = gflop / avg_runtime_ms;
  state.counters["avg_throughput"] = mega_bytes_transferred / avg_runtime_ms;
  state.counters["best_tflop"] = gflop / best_runtime_ms;
  state.counters["best_bandwidth"] = mega_bytes_transferred / best_runtime_ms;
}
};

Expand Down

0 comments on commit d4f99c9

Please sign in to comment.