Skip to content

Commit

Permalink
Use ClimaTimesteppers benchmark_step
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed May 4, 2024
1 parent 0bd9851 commit 72e1391
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 111 deletions.
8 changes: 8 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,14 @@ steps:
agents:
slurm_mem: 24GB

- label: ":computer: Benchmark: perf target (gpu)"
command: >
julia --color=yes --project=perf perf/benchmark.jl
$PERF_CONFIG_PATH/bm_perf_target.yml
agents:
slurm_mem: 24GB
slurm_gpus: 1

- label: ":computer: Benchmark: perf target (Threaded)"
command: >
julia --color=yes --threads 8 --project=perf perf/benchmark.jl
Expand Down
52 changes: 13 additions & 39 deletions perf/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,51 +5,26 @@ import ClimaAtmos as CA

include("common.jl")

using CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables # needed for CTS.benchmark_step
using Test
using ClimaComms
import SciMLBase
import ClimaTimeSteppers as CTS

length(ARGS) != 1 && error("Usage: benchmark.jl <config_file>")
config_file = ARGS[1]
config_dict = YAML.load_file(config_file)
config = AtmosCoveragePerfConfig(config_dict)
config = AtmosCoveragePerfConfig(config_dict);

simulation = CA.get_simulation(config)
(; integrator) = simulation
simulation = CA.get_simulation(config);
(; integrator) = simulation;
(; parsed_args) = config;

(; parsed_args) = config
device = ClimaComms.device()
(; table_summary, trials) = CTS.benchmark_step(integrator, device)

import SciMLBase
import ClimaTimeSteppers as CTS
SciMLBase.step!(integrator) # compile first

(; sol, u, p, dt, t) = integrator

W = get_W(integrator)
X = similar(u)

include("benchmark_utils.jl")

import OrderedCollections
import LinearAlgebra as LA
trials = OrderedCollections.OrderedDict()
#! format: off
trials["Wfact"] = get_trial(wfact_fun(integrator), (W, u, p, dt, t), "Wfact");
trials["linsolve"] = get_trial(LA.ldiv!, (X, W, u), "linsolve");
trials["implicit_tendency!"] = get_trial(implicit_fun(integrator), implicit_args(integrator), "implicit_tendency!");
trials["remaining_tendency!"] = get_trial(remaining_fun(integrator), remaining_args(integrator), "remaining_tendency!");
trials["additional_tendency!"] = get_trial(CA.additional_tendency!, (X, u, p, t), "additional_tendency!");
trials["hyperdiffusion_tendency!"] = get_trial(CA.hyperdiffusion_tendency!, remaining_args(integrator), "hyperdiffusion_tendency!");
trials["dss!"] = get_trial(CA.dss!, (u, p, t), "dss!");
trials["set_precomputed_quantities!"] = get_trial(CA.set_precomputed_quantities!, (u, p, t), "set_precomputed_quantities!");
trials["step!"] = get_trial(SciMLBase.step!, (integrator, ), "step!");
#! format: on

using Test
using ClimaComms

table_summary = OrderedCollections.OrderedDict()
for k in keys(trials)
table_summary[k] = get_summary(trials[k])
end
tabulate_summary(table_summary)

are_boundschecks_forced = Base.JLOptions().check_bounds == 1
# Benchmark allocation tests
@testset "Benchmark allocation tests" begin
Expand Down Expand Up @@ -81,7 +56,6 @@ if get(ENV, "BUILDKITE", "") == "true"
end
end

import ClimaComms
if config.comms_ctx isa ClimaComms.SingletonCommsContext && !isinteractive()
include(joinpath(pkgdir(CA), "perf", "jet_report_nfailures.jl"))
end
end
72 changes: 0 additions & 72 deletions perf/benchmark_utils.jl

This file was deleted.

0 comments on commit 72e1391

Please sign in to comment.