diff --git a/Project.toml b/Project.toml index 9b028987..0de0a0d0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ClimaTimeSteppers" uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79" authors = ["Climate Modeling Alliance"] -version = "0.7.37" +version = "0.7.38" [deps] ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" diff --git a/perf/jet.jl b/perf/jet.jl index ad230cb4..9a78c1c5 100644 --- a/perf/jet.jl +++ b/perf/jet.jl @@ -60,5 +60,5 @@ end JET.@test_opt CTS.step_u!(integrator, integrator.cache) CTS.__step!(integrator) # compile first, and make sure it runs - JET.@test_opt broken = true CTS.__step!(integrator) + JET.@test_opt CTS.__step!(integrator) end diff --git a/src/integrators.jl b/src/integrators.jl index 66bd8441..09345304 100644 --- a/src/integrators.jl +++ b/src/integrators.jl @@ -225,6 +225,17 @@ is_past_t(integrator, t) = tdir(integrator) * (t - integrator.t) < zero(integrat reached_tstop(integrator, tstop, stop_at_tstop = integrator.dtchangeable) = integrator.t == tstop || (!stop_at_tstop && is_past_t(integrator, tstop)) + +@inline unrolled_foreach(::Tuple{}, integrator) = nothing +@inline unrolled_foreach(callback, integrator) = + callback.condition(integrator.u, integrator.t, integrator) ? callback.affect!(integrator) : nothing +@inline unrolled_foreach(discrete_callbacks::Tuple{Any}, integrator) = + unrolled_foreach(first(discrete_callbacks), integrator) +@inline function unrolled_foreach(discrete_callbacks::Tuple, integrator) + unrolled_foreach(first(discrete_callbacks), integrator) + unrolled_foreach(Base.tail(discrete_callbacks), integrator) +end + function __step!(integrator) (; _dt, dtchangeable, tstops) = integrator @@ -246,13 +257,7 @@ function __step!(integrator) # apply callbacks discrete_callbacks = integrator.callback.discrete_callbacks - for (ncb, callback) in enumerate(discrete_callbacks) - if callback.condition(integrator.u, integrator.t, integrator)::Bool - NVTX.@range "Callback $ncb of $(length(discrete_callbacks))" color = colorant"yellow" begin - callback.affect!(integrator) - end - end - end + unrolled_foreach(discrete_callbacks, integrator) # remove tstops that were just reached while !isempty(tstops) && reached_tstop(integrator, first(tstops))