From a9d2adf2779f3f6699ee8766a233a5f0dfcc36dd Mon Sep 17 00:00:00 2001
From: Charles Kawczynski
Date: Fri, 3 May 2024 14:30:43 -0400
Subject: [PATCH] Use ClimaTimeSteppers benchmark_step

---
 .buildkite/pipeline.yml |  8 +++++
 perf/benchmark.jl       | 46 ++++++--------------
 perf/benchmark_utils.jl | 72 -----------------------------------------
 3 files changed, 18 insertions(+), 108 deletions(-)
 delete mode 100644 perf/benchmark_utils.jl

diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index 4e12224c5a7..679c21ed5ff 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -820,6 +820,14 @@ steps:
     agents:
       slurm_mem: 24GB
 
+  - label: ":computer: Benchmark: perf target (gpu)"
+    command: >
+      julia --color=yes --project=perf perf/benchmark.jl
+      $PERF_CONFIG_PATH/bm_perf_target.yml
+    agents:
+      slurm_mem: 24GB
+      slurm_gpus: 1
+
   - label: ":computer: Benchmark: perf target (Threaded)"
     command: >
       julia --color=yes --threads 8 --project=perf perf/benchmark.jl
diff --git a/perf/benchmark.jl b/perf/benchmark.jl
index 715385713c5..9cb467cbb10 100644
--- a/perf/benchmark.jl
+++ b/perf/benchmark.jl
@@ -5,51 +5,26 @@ import ClimaAtmos as CA
 
 include("common.jl")
 
+using CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables # needed for CTS.benchmark_step
+
 length(ARGS) != 1 && error("Usage: benchmark.jl <config_file>")
 config_file = ARGS[1]
 config_dict = YAML.load_file(config_file)
-config = AtmosCoveragePerfConfig(config_dict)
+config = AtmosCoveragePerfConfig(config_dict);
 
-simulation = CA.get_simulation(config)
-(; integrator) = simulation
+simulation = CA.get_simulation(config);
+(; integrator) = simulation;
+(; parsed_args) = config;
 
-(; parsed_args) = config
+device = ClimaComms.device()
+(; table_summary, trials) = CTS.benchmark_step(integrator, device)
 
+using Test
+using ClimaComms
 import SciMLBase
 import ClimaTimeSteppers as CTS
 SciMLBase.step!(integrator) # compile first
 
-(; sol, u, p, dt, t) = integrator
-
-W = get_W(integrator)
-X = similar(u)
-
-include("benchmark_utils.jl")
-
-import OrderedCollections
-import LinearAlgebra as LA
-trials = OrderedCollections.OrderedDict()
-#! format: off
-trials["Wfact"] = get_trial(wfact_fun(integrator), (W, u, p, dt, t), "Wfact");
-trials["linsolve"] = get_trial(LA.ldiv!, (X, W, u), "linsolve");
-trials["implicit_tendency!"] = get_trial(implicit_fun(integrator), implicit_args(integrator), "implicit_tendency!");
-trials["remaining_tendency!"] = get_trial(remaining_fun(integrator), remaining_args(integrator), "remaining_tendency!");
-trials["additional_tendency!"] = get_trial(CA.additional_tendency!, (X, u, p, t), "additional_tendency!");
-trials["hyperdiffusion_tendency!"] = get_trial(CA.hyperdiffusion_tendency!, remaining_args(integrator), "hyperdiffusion_tendency!");
-trials["dss!"] = get_trial(CA.dss!, (u, p, t), "dss!");
-trials["set_precomputed_quantities!"] = get_trial(CA.set_precomputed_quantities!, (u, p, t), "set_precomputed_quantities!");
-trials["step!"] = get_trial(SciMLBase.step!, (integrator, ), "step!");
-#! format: on
-
-using Test
-using ClimaComms
-
-table_summary = OrderedCollections.OrderedDict()
-for k in keys(trials)
-    table_summary[k] = get_summary(trials[k])
-end
-tabulate_summary(table_summary)
-
 are_boundschecks_forced = Base.JLOptions().check_bounds == 1
 # Benchmark allocation tests
 @testset "Benchmark allocation tests" begin
@@ -81,7 +56,6 @@ if get(ENV, "BUILDKITE", "") == "true"
     end
 end
 
-import ClimaComms
 if config.comms_ctx isa ClimaComms.SingletonCommsContext && !isinteractive()
     include(joinpath(pkgdir(CA), "perf", "jet_report_nfailures.jl"))
 end
diff --git a/perf/benchmark_utils.jl b/perf/benchmark_utils.jl
deleted file mode 100644
index 73c35814931..00000000000
--- a/perf/benchmark_utils.jl
+++ /dev/null
@@ -1,72 +0,0 @@
-import StatsBase
-import PrettyTables
-import BenchmarkTools
-
-#####
-##### BenchmarkTools's trial utils
-#####
-
-get_summary(trial) = (;
-    # Using some BenchmarkTools internals :/
-    mem = BenchmarkTools.prettymemory(trial.memory),
-    mem_val = trial.memory,
-    nalloc = trial.allocs,
-    t_min = BenchmarkTools.prettytime(minimum(trial.times)),
-    t_max = BenchmarkTools.prettytime(maximum(trial.times)),
-    t_mean = BenchmarkTools.prettytime(StatsBase.mean(trial.times)),
-    t_mean_val = StatsBase.mean(trial.times),
-    t_med = BenchmarkTools.prettytime(StatsBase.median(trial.times)),
-    n_samples = length(trial),
-)
-
-function tabulate_summary(summary)
-    summary_keys = collect(keys(summary))
-    mem = map(k -> summary[k].mem, summary_keys)
-    nalloc = map(k -> summary[k].nalloc, summary_keys)
-    t_mean = map(k -> summary[k].t_mean, summary_keys)
-    t_min = map(k -> summary[k].t_min, summary_keys)
-    t_max = map(k -> summary[k].t_max, summary_keys)
-    t_med = map(k -> summary[k].t_med, summary_keys)
-    n_samples = map(k -> summary[k].n_samples, summary_keys)
-
-    table_data = hcat(
-        string.(collect(keys(summary))),
-        mem,
-        nalloc,
-        t_min,
-        t_max,
-        t_mean,
-        t_med,
-        n_samples,
-    )
-
-    header = (
-        [
-            "Function",
-            "Memory",
-            "allocs",
-            "Time",
-            "Time",
-            "Time",
-            "Time",
-            "N-samples",
-        ],
-        [" ", "estimate", "estimate", "min", "max", "mean", "median", ""],
-    )
-
-    PrettyTables.pretty_table(
-        table_data;
-        header,
-        crop = :none,
-        alignment = vcat(:l, repeat([:r], length(header[1]) - 1)),
-    )
-end
-
-function get_trial(f, args, name)
-    sample_limit = 10
-    f(args...) # compile first
-    b = BenchmarkTools.@benchmarkable $f($(args)...)
-    println("Benchmarking $name...")
-    trial = BenchmarkTools.run(b, samples = sample_limit)
-    return trial
-end
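
Usage note: the new GPU step runs the same entry point as the existing CPU benchmark steps, `julia --color=yes --project=perf perf/benchmark.jl $PERF_CONFIG_PATH/bm_perf_target.yml`, and requests a device via `slurm_gpus: 1`. Since `ClimaComms.device()` selects the CPU or CUDA device from the environment (e.g. via the `CLIMACOMMS_DEVICE` variable), one script serves both pipeline variants.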
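
For poking at the results interactively, here is a minimal sketch (not part of this patch) of summarizing the returned trials. It assumes `CTS.benchmark_step` hands back `trials` as a dictionary of `BenchmarkTools.Trial` values keyed by stage name, analogous to what the deleted `get_trial`/`get_summary` helpers produced:

```julia
import BenchmarkTools
import StatsBase

# Per-stage median time and allocation count, read from the same Trial
# fields (`times`, `allocs`) that the removed get_summary helper used.
# Assumes `trials` comes from CTS.benchmark_step as in perf/benchmark.jl.
for (name, trial) in trials
    t_med = BenchmarkTools.prettytime(StatsBase.median(trial.times))
    println(rpad(name, 32), " median: ", t_med, ", allocs: ", trial.allocs)
end
```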