Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip: sde integrator #265

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

Expand Down
7 changes: 5 additions & 2 deletions src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ export ClimaODEFunction, ForwardEulerODEFunction

abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end

struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction
struct ClimaODEFunction{TEL, TL, TE, TES, TI, L, D, PE, PI} <: AbstractClimaODEFunction
T_exp_T_lim!::TEL
T_lim!::TL
T_exp!::TE
T_stoch!::TES
T_imp!::TI
lim!::L
dss!::D
Expand All @@ -17,13 +18,14 @@ struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFuncti
T_exp_T_lim! = nothing, # nothing or (uₜ_exp, uₜ_lim, u, p, t) -> ...
T_lim! = nothing, # nothing or (uₜ, u, p, t) -> ...
T_exp! = nothing, # nothing or (uₜ, u, p, t) -> ...
T_stoch! = nothing, # nothing or (uₜ, u, p, t) -> ...
T_imp! = nothing, # nothing or (uₜ, u, p, t) -> ...
lim! = (u, p, t, u_ref) -> nothing,
dss! = (u, p, t) -> nothing,
post_explicit! = (u, p, t) -> nothing,
post_implicit! = (u, p, t) -> nothing,
)
args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, post_explicit!, post_implicit!)
args = (T_exp_T_lim!, T_lim!, T_exp!, T_stoch!, T_imp!, lim!, dss!, post_explicit!, post_implicit!)

if !isnothing(T_exp_T_lim!)
@assert isnothing(T_exp!) "`T_exp_T_lim!` was passed, `T_exp!` must be `nothing`"
Expand All @@ -37,6 +39,7 @@ struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFuncti
end

has_T_exp(f::ClimaODEFunction) = !isnothing(f.T_exp!) || !isnothing(f.T_exp_T_lim!)
has_T_stoch(f::ClimaODEFunction) = !isnothing(f.T_stoch!)
has_T_lim(f::ClimaODEFunction) = !isnothing(f.lim!) && (!isnothing(f.T_lim!) || !isnothing(f.T_exp_T_lim!))

# Don't wrap a AbstractClimaODEFunction in an ODEFunction (makes ODEProblem work).
Expand Down
25 changes: 20 additions & 5 deletions src/solvers/imex_ark.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import NVTX
# using Random # for testing

has_jac(T_imp!) =
hasfield(typeof(T_imp!), :Wfact) &&
Expand All @@ -17,6 +18,7 @@ struct IMEXARKCache{SCU, SCE, SCI, T, Γ, NMC}
U::SCU # sparse container of length s
T_lim::SCE # sparse container of length s
T_exp::SCE # sparse container of length s
T_stoch::SCE # sparse container of length s
T_imp::SCI # sparse container of length s
temp::T
γ::Γ
Expand All @@ -35,25 +37,26 @@ function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{Unco
U = zero(u0)
T_lim = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_exp))), inds_T_exp)
T_exp = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_exp))), inds_T_exp)
T_stoch = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_exp))), inds_T_exp)
T_imp = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_imp))), inds_T_imp)
temp = zero(u0)
γs = unique(filter(!iszero, diag(a_imp)))
γ = length(γs) == 1 ? γs[1] : nothing # TODO: This could just be a constant.
jac_prototype = has_jac(T_imp!) ? T_imp!.jac_prototype : nothing
newtons_method_cache =
isnothing(T_imp!) || isnothing(newtons_method) ? nothing : allocate_cache(newtons_method, u0, jac_prototype)
return IMEXARKCache(U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache)
return IMEXARKCache(U, T_lim, T_exp, T_stoch, T_imp, temp, γ, newtons_method_cache)
end

# generic fallback
function step_u!(integrator, cache::IMEXARKCache)
(; u, p, t, dt, alg) = integrator
(; f) = integrator.sol.prob
(; post_explicit!, post_implicit!) = f
(; T_lim!, T_exp!, T_imp!, lim!, dss!) = f
(; T_lim!, T_exp!, T_stoch!, T_imp!, lim!, dss!) = f
(; tableau, newtons_method) = alg
(; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau
(; U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache
(; U, T_lim, T_exp, T_stoch, T_imp, temp, γ, newtons_method_cache) = cache
s = length(b_exp)

if !isnothing(T_imp!) && !isnothing(newtons_method)
Expand All @@ -69,6 +72,7 @@ function step_u!(integrator, cache::IMEXARKCache)
end

update_stage!(integrator, cache, ntuple(i -> i, Val(s)))
return @. u = U # hack to get SDE solver to work

t_final = t + dt

Expand All @@ -80,6 +84,9 @@ function step_u!(integrator, cache::IMEXARKCache)

# Update based on tendencies from previous stages
has_T_exp(f) && fused_increment!(u, dt, b_exp, T_exp, Val(s))

has_T_stoch(f) && fused_increment!(u, dt, b_exp, T_stoch, Val(s))

isnothing(T_imp!) || fused_increment!(u, dt, b_imp, T_imp, Val(s))

dss!(u, p, t_final)
Expand All @@ -99,15 +106,17 @@ end
(; u, p, t, dt, alg) = integrator
(; f) = integrator.sol.prob
(; post_explicit!, post_implicit!) = f
(; T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!) = f
(; T_exp_T_lim!, T_lim!, T_exp!, T_stoch!, T_imp!, lim!, dss!) = f
(; tableau, newtons_method) = alg
(; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau
(; U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache
(; U, T_lim, T_exp, T_stoch, T_imp, temp, γ, newtons_method_cache) = cache
s = length(b_exp)

t_exp = t + dt * c_exp[i]
t_imp = t + dt * c_imp[i]



if has_T_lim(f) # Update based on limited tendencies from previous stages
assign_fused_increment!(U, u, dt, a_exp, T_lim, Val(i))
i ≠ 1 && lim!(U, p, t_exp, u)
Expand All @@ -116,6 +125,11 @@ end
end

# Update based on tendencies from previous stages
if i ≠ 1 && has_T_stoch(f)
dW = √dt * randn(eltype(dt))
push!(p.dW, dW)
fused_increment!(U, dW, a_exp, T_stoch, Val(i))
end
has_T_exp(f) && fused_increment!(U, dt, a_exp, T_exp, Val(i))
isnothing(T_imp!) || fused_increment!(U, dt, a_imp, T_imp, Val(i))

Expand Down Expand Up @@ -175,6 +189,7 @@ end
isnothing(T_lim!) || T_lim!(T_lim[i], U, p, t_exp)
isnothing(T_exp!) || T_exp!(T_exp[i], U, p, t_exp)
end
isnothing(T_stoch!) || T_stoch!(T_stoch[i], U, p, t_exp)
end

return nothing
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -21,6 +22,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Expand Down
118 changes: 116 additions & 2 deletions test/problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,120 @@ import ClimaCore: Domains, Geometry, Meshes, Topologies, Spaces, Fields, Operato

import Krylov # Trigger ClimaCore/ext/KrylovExt

function geometric_BM()
f!(uₜ, u, p, t) = @. uₜ = p.μ * u
g!(uₜ, u, p, t) = @. uₜ = p.σ * u
ClimaODEFunction(; T_exp! = f!, T_stoch! = g!)
end

function geometric_BM_prob(;
u0 = [1 / 2],
p = (; μ = 1.01, σ = 0.5, dW = Float64[]),
tspan = (0.0, 1.0),
)
func = geometric_BM()
ODEProblem(func, u0, tspan, p)
end

geometric_BM_analytical(u0, p, t, W) = @. u0 * exp((p.μ - p.σ^2 / 2) * t + p.σ * W)
geometric_BM_analytical(sol) = begin
u0 = sol.u[1][1]
p = sol.prob.p
W = [0; cumsum(p.dW)]
geometric_BM_analytical(u0, p, sol.t, W)
end

function OU()
f!(uₜ, u, p, t) = @. uₜ = p.μ * (p.θ - u)
g!(uₜ, u, p, t) = @. uₜ = p.σ
ClimaODEFunction(; T_exp! = f!, T_stoch! = g!)
end

function OU_prob(;
u0 = [1 / 2],
p = (; μ = 1.01, θ = 1.0, σ = 1.0, dW = Float64[]),
tspan = (0.0, 1.0),
)
func = OU()
ODEProblem(func, u0, tspan, p)
end

function OU_AR_process(; uprev, h, p, z = randn())
(;μ, θ, σ) = p
exp_minus_θh = @. exp(-θ * h)
umean = uprev * exp_minus_θh + μ * (1 - exp_minus_θh)
unoise = √(σ^2 / 2θ * (1 - exp_minus_θh^2)) * z
unext = umean + unoise
unext
end
function OU_AR_process_solve(u0, p, t, z)
@assert length(z) == length(t) - 1
u = zeros(length(t)) # preallocate
u[1] = u0
for i in 2:length(t)
u[i] = OU_AR_process(uprev = u[i-1], h = t[i] - t[i-1], p = p, z = z[i-1])
end
u
end

function OU_analytical(u0, p, t, dW)
(;μ, θ, σ) = p
exp_minus_θt = @. exp(-θ * t)
exp_θt_dW = @. exp(θ * t) * [0; dW]
∫exp_θt_dW = cumsum(exp_θt_dW)
@. u0 * exp_minus_θt + μ * (1 - exp_minus_θt) + σ * exp_minus_θt * ∫exp_θt_dW
end
OU_analytical(sol) = begin
u0 = sol.u[1][1]
p = sol.prob.p
OU_analytical(u0, p, sol.t, p.dW)
end

using GLMakie

# Example usage:
# ts, sols, an_sols, AR_sols = mysolve(n=10, prob = OU_prob(), analytical = OU_analytical, AR = OU_AR_process_solve)
# plot_sols(ts, sols, an_sols, AR_sols)
function mysolve(n=1; prob, analytical = nothing, AR = nothing, dt = 0.1)
u0 = prob.u0[1]
alg = IMEXAlgorithm(ARS111(), nothing); # Explicit Euler
sols = Vector{Float64}[]
an_sols = Vector{Float64}[]
AR_sols = Vector{Float64}[]
ts = Float64[]
for _ in 1:n
prob.u0[1] = u0 # reinit
sol = solve(prob, alg; dt, save_everystep = true)
ts = sol.t
push!(sols, vcat(sol.u...))
if !isnothing(analytical)
an_sol = analytical(sol)
push!(an_sols, vcat(an_sol...))
end
if !isnothing(AR)
z = sol.prob.p.dW / √dt # rescale to AR(1)
AR_sol = AR(u0, prob.p, sol.t, z)
push!(AR_sols, AR_sol)
end
empty!(prob.p.dW)
end
sols = hcat(sols...)
!isnothing(analytical) && (an_sols = hcat(an_sols...))
!isnothing(AR) && (AR_sols = hcat(AR_sols...))
return ts, sols, an_sols, AR_sols
end

function plot_sols(ts, sols, an_sols = nothing, AR_sols = nothing)
fig = Figure()
ax = Axis(fig[1,1])
lines!.(ax, Ref(ts), eachcol(sols); color=:black, label = "numerical")
isnothing(an_sols) || lines!.(ax, Ref(ts), eachcol(an_sols); color=:red, label = "analytical")
isnothing(AR_sols) || lines!.(ax, Ref(ts), eachcol(AR_sols); color=:blue, label = "AR(1)")
axislegend(; position=:lt, unique=true)
fig
end


"""
Single variable linear ODE

Expand All @@ -18,7 +132,7 @@ u(t) = u_0 e^{αt}

This is an in-place variant of the one from DiffEqProblemLibrary.jl.
"""
function linear_prob()
function linear_prob() ## nice thing
ODEProblem(
IncrementingODEFunction{true}((du, u, p, t, α = true, β = false) -> (du .= α .* p .* u .+ β .* du)),
[1 / 2],
Expand Down Expand Up @@ -539,7 +653,7 @@ function climacore_1Dheat_test_cts(::Type{FT}) where {FT}
return state
end

tendency_func = ClimaODEFunction(; T_exp!)
tendency_func = ClimaODEFunction(; T_exp!) # example of explicit-only ClimaODEFunction
split_tendency_func = tendency_func
make_prob(func) = ODEProblem(func, init_state, (FT(0), t_end), nothing)
IntegratorTestCase(
Expand Down
Loading