diff --git a/Project.toml b/Project.toml index f670d14e..086978c6 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/src/functions.jl b/src/functions.jl index 04c92f4d..08cac1eb 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -4,10 +4,11 @@ export ClimaODEFunction, ForwardEulerODEFunction abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end -struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction +struct ClimaODEFunction{TEL, TL, TE, TES, TI, L, D, PE, PI} <: AbstractClimaODEFunction T_exp_T_lim!::TEL T_lim!::TL T_exp!::TE + T_stoch!::TES T_imp!::TI lim!::L dss!::D @@ -17,13 +18,14 @@ struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFuncti T_exp_T_lim! = nothing, # nothing or (uₜ_exp, uₜ_lim, u, p, t) -> ... T_lim! = nothing, # nothing or (uₜ, u, p, t) -> ... T_exp! = nothing, # nothing or (uₜ, u, p, t) -> ... + T_stoch! = nothing, # nothing or (uₜ, u, p, t) -> ... T_imp! = nothing, # nothing or (uₜ, u, p, t) -> ... lim! = (u, p, t, u_ref) -> nothing, dss! = (u, p, t) -> nothing, post_explicit! = (u, p, t) -> nothing, post_implicit! = (u, p, t) -> nothing, ) - args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, post_explicit!, post_implicit!) + args = (T_exp_T_lim!, T_lim!, T_exp!, T_stoch!, T_imp!, lim!, dss!, post_explicit!, post_implicit!) if !isnothing(T_exp_T_lim!) @assert isnothing(T_exp!) "`T_exp_T_lim!` was passed, `T_exp!` must be `nothing`" @@ -37,6 +39,7 @@ struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFuncti end has_T_exp(f::ClimaODEFunction) = !isnothing(f.T_exp!) || !isnothing(f.T_exp_T_lim!) +has_T_stoch(f::ClimaODEFunction) = !isnothing(f.T_stoch!) has_T_lim(f::ClimaODEFunction) = !isnothing(f.lim!) && (!isnothing(f.T_lim!) || !isnothing(f.T_exp_T_lim!)) # Don't wrap a AbstractClimaODEFunction in an ODEFunction (makes ODEProblem work). diff --git a/src/solvers/imex_ark.jl b/src/solvers/imex_ark.jl index 4c2d24f9..36b15ec5 100644 --- a/src/solvers/imex_ark.jl +++ b/src/solvers/imex_ark.jl @@ -1,4 +1,5 @@ import NVTX +# using Random # for testing has_jac(T_imp!) = hasfield(typeof(T_imp!), :Wfact) && @@ -17,6 +18,7 @@ struct IMEXARKCache{SCU, SCE, SCI, T, Γ, NMC} U::SCU # sparse container of length s T_lim::SCE # sparse container of length s T_exp::SCE # sparse container of length s + T_stoch::SCE # sparse container of length s T_imp::SCI # sparse container of length s temp::T γ::Γ @@ -35,6 +37,7 @@ function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{Unco U = zero(u0) T_lim = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_exp))), inds_T_exp) T_exp = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_exp))), inds_T_exp) + T_stoch = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_exp))), inds_T_exp) T_imp = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_imp))), inds_T_imp) temp = zero(u0) γs = unique(filter(!iszero, diag(a_imp))) @@ -42,7 +45,7 @@ function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{Unco jac_prototype = has_jac(T_imp!) ? T_imp!.jac_prototype : nothing newtons_method_cache = isnothing(T_imp!) || isnothing(newtons_method) ? nothing : allocate_cache(newtons_method, u0, jac_prototype) - return IMEXARKCache(U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) + return IMEXARKCache(U, T_lim, T_exp, T_stoch, T_imp, temp, γ, newtons_method_cache) end # generic fallback @@ -50,10 +53,10 @@ function step_u!(integrator, cache::IMEXARKCache) (; u, p, t, dt, alg) = integrator (; f) = integrator.sol.prob (; post_explicit!, post_implicit!) = f - (; T_lim!, T_exp!, T_imp!, lim!, dss!) = f + (; T_lim!, T_exp!, T_stoch!, T_imp!, lim!, dss!) = f (; tableau, newtons_method) = alg (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau - (; U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache + (; U, T_lim, T_exp, T_stoch, T_imp, temp, γ, newtons_method_cache) = cache s = length(b_exp) if !isnothing(T_imp!) && !isnothing(newtons_method) @@ -69,6 +72,7 @@ function step_u!(integrator, cache::IMEXARKCache) end update_stage!(integrator, cache, ntuple(i -> i, Val(s))) + return @. u = U # hack to get SDE solver to work t_final = t + dt @@ -80,6 +84,9 @@ function step_u!(integrator, cache::IMEXARKCache) # Update based on tendencies from previous stages has_T_exp(f) && fused_increment!(u, dt, b_exp, T_exp, Val(s)) + + has_T_stoch(f) && fused_increment!(u, dt, b_exp, T_stoch, Val(s)) + isnothing(T_imp!) || fused_increment!(u, dt, b_imp, T_imp, Val(s)) dss!(u, p, t_final) @@ -99,15 +106,17 @@ end (; u, p, t, dt, alg) = integrator (; f) = integrator.sol.prob (; post_explicit!, post_implicit!) = f - (; T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!) = f + (; T_exp_T_lim!, T_lim!, T_exp!, T_stoch!, T_imp!, lim!, dss!) = f (; tableau, newtons_method) = alg (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau - (; U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache + (; U, T_lim, T_exp, T_stoch, T_imp, temp, γ, newtons_method_cache) = cache s = length(b_exp) t_exp = t + dt * c_exp[i] t_imp = t + dt * c_imp[i] + + if has_T_lim(f) # Update based on limited tendencies from previous stages assign_fused_increment!(U, u, dt, a_exp, T_lim, Val(i)) i ≠ 1 && lim!(U, p, t_exp, u) @@ -116,6 +125,11 @@ end end # Update based on tendencies from previous stages + if i ≠ 1 && has_T_stoch(f) + dW = √dt * randn(eltype(dt)) + push!(p.dW, dW) + fused_increment!(U, dW, a_exp, T_stoch, Val(i)) + end has_T_exp(f) && fused_increment!(U, dt, a_exp, T_exp, Val(i)) isnothing(T_imp!) || fused_increment!(U, dt, a_imp, T_imp, Val(i)) @@ -175,6 +189,7 @@ end isnothing(T_lim!) || T_lim!(T_lim[i], U, p, t_exp) isnothing(T_exp!) || T_exp!(T_exp[i], U, p, t_exp) end + isnothing(T_stoch!) || T_stoch!(T_stoch[i], U, p, t_exp) end return nothing diff --git a/test/Project.toml b/test/Project.toml index 544b18ba..1c369b42 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,6 +9,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -21,6 +22,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/problems.jl b/test/problems.jl index fd12df3d..34d13a18 100644 --- a/test/problems.jl +++ b/test/problems.jl @@ -5,6 +5,120 @@ import ClimaCore: Domains, Geometry, Meshes, Topologies, Spaces, Fields, Operato import Krylov # Trigger ClimaCore/ext/KrylovExt +function geometric_BM() + f!(uₜ, u, p, t) = @. uₜ = p.μ * u + g!(uₜ, u, p, t) = @. uₜ = p.σ * u + ClimaODEFunction(; T_exp! = f!, T_stoch! = g!) +end + +function geometric_BM_prob(; + u0 = [1 / 2], + p = (; μ = 1.01, σ = 0.5, dW = Float64[]), + tspan = (0.0, 1.0), + ) + func = geometric_BM() + ODEProblem(func, u0, tspan, p) +end + +geometric_BM_analytical(u0, p, t, W) = @. u0 * exp((p.μ - p.σ^2 / 2) * t + p.σ * W) +geometric_BM_analytical(sol) = begin + u0 = sol.u[1][1] + p = sol.prob.p + W = [0; cumsum(p.dW)] + geometric_BM_analytical(u0, p, sol.t, W) +end + +function OU() + f!(uₜ, u, p, t) = @. uₜ = p.μ * (p.θ - u) + g!(uₜ, u, p, t) = @. uₜ = p.σ + ClimaODEFunction(; T_exp! = f!, T_stoch! = g!) +end + +function OU_prob(; + u0 = [1 / 2], + p = (; μ = 1.01, θ = 1.0, σ = 1.0, dW = Float64[]), + tspan = (0.0, 1.0), + ) + func = OU() + ODEProblem(func, u0, tspan, p) +end + +function OU_AR_process(; uprev, h, p, z = randn()) + (;μ, θ, σ) = p + exp_minus_θh = @. exp(-θ * h) + umean = uprev * exp_minus_θh + μ * (1 - exp_minus_θh) + unoise = √(σ^2 / 2θ * (1 - exp_minus_θh^2)) * z + unext = umean + unoise + unext +end +function OU_AR_process_solve(u0, p, t, z) + @assert length(z) == length(t) - 1 + u = zeros(length(t)) # preallocate + u[1] = u0 + for i in 2:length(t) + u[i] = OU_AR_process(uprev = u[i-1], h = t[i] - t[i-1], p = p, z = z[i-1]) + end + u +end + +function OU_analytical(u0, p, t, dW) + (;μ, θ, σ) = p + exp_minus_θt = @. exp(-θ * t) + exp_θt_dW = @. exp(θ * t) * [0; dW] + ∫exp_θt_dW = cumsum(exp_θt_dW) + @. u0 * exp_minus_θt + μ * (1 - exp_minus_θt) + σ * exp_minus_θt * ∫exp_θt_dW +end +OU_analytical(sol) = begin + u0 = sol.u[1][1] + p = sol.prob.p + OU_analytical(u0, p, sol.t, p.dW) +end + +using GLMakie + +# Example usage: +# ts, sols, an_sols, AR_sols = mysolve(n=10, prob = OU_prob(), analytical = OU_analytical, AR = OU_AR_process_solve) +# plot_sols(ts, sols, an_sols, AR_sols) +function mysolve(n=1; prob, analytical = nothing, AR = nothing, dt = 0.1) + u0 = prob.u0[1] + alg = IMEXAlgorithm(ARS111(), nothing); # Explicit Euler + sols = Vector{Float64}[] + an_sols = Vector{Float64}[] + AR_sols = Vector{Float64}[] + ts = Float64[] + for _ in 1:n + prob.u0[1] = u0 # reinit + sol = solve(prob, alg; dt, save_everystep = true) + ts = sol.t + push!(sols, vcat(sol.u...)) + if !isnothing(analytical) + an_sol = analytical(sol) + push!(an_sols, vcat(an_sol...)) + end + if !isnothing(AR) + z = sol.prob.p.dW / √dt # rescale to AR(1) + AR_sol = AR(u0, prob.p, sol.t, z) + push!(AR_sols, AR_sol) + end + empty!(prob.p.dW) + end + sols = hcat(sols...) + !isnothing(analytical) && (an_sols = hcat(an_sols...)) + !isnothing(AR) && (AR_sols = hcat(AR_sols...)) + return ts, sols, an_sols, AR_sols +end + +function plot_sols(ts, sols, an_sols = nothing, AR_sols = nothing) + fig = Figure() + ax = Axis(fig[1,1]) + lines!.(ax, Ref(ts), eachcol(sols); color=:black, label = "numerical") + isnothing(an_sols) || lines!.(ax, Ref(ts), eachcol(an_sols); color=:red, label = "analytical") + isnothing(AR_sols) || lines!.(ax, Ref(ts), eachcol(AR_sols); color=:blue, label = "AR(1)") + axislegend(; position=:lt, unique=true) + fig +end + + """ Single variable linear ODE @@ -18,7 +132,7 @@ u(t) = u_0 e^{αt} This is an in-place variant of the one from DiffEqProblemLibrary.jl. """ -function linear_prob() +function linear_prob() ## nice thing ODEProblem( IncrementingODEFunction{true}((du, u, p, t, α = true, β = false) -> (du .= α .* p .* u .+ β .* du)), [1 / 2], @@ -539,7 +653,7 @@ function climacore_1Dheat_test_cts(::Type{FT}) where {FT} return state end - tendency_func = ClimaODEFunction(; T_exp!) + tendency_func = ClimaODEFunction(; T_exp!) # example of explicit-only ClimaODEFunction split_tendency_func = tendency_func make_prob(func) = ODEProblem(func, init_state, (FT(0), t_end), nothing) IntegratorTestCase(