Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dt_cache infrastructure for multirate #26

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 83 additions & 11 deletions src/integrators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,22 @@ end

# called by DiffEqBase.init and solve (see below)
function DiffEqBase.__init(
prob::DiffEqBase.AbstractODEProblem,
alg::DistributedODEAlgorithm,
args...;
prob::DiffEqBase.AbstractODEProblem,
alg::DistributedODEAlgorithm,
args...;
dt, # required
stepstop=-1,
stepstop=-1,
adjustfinal=false,
callback=nothing,
kwargs...)
kwargs...)

u = prob.u0
t = prob.tspan[1]
tstop = prob.tspan[2]

callbackset = DiffEqBase.CallbackSet(callback)
isempty(callbackset.continuous_callbacks) || error("Continuous callbacks are not supported")
integrator = DistributedODEIntegrator(prob, alg, u, dt, t, tstop, 0, stepstop, adjustfinal, callbackset, false, cache(prob, alg; dt=dt, kwargs...))
integrator = DistributedODEIntegrator(prob, alg, u, dt, t, tstop, 0, stepstop, adjustfinal, callbackset, false, init_cache(prob, alg; dt=dt, kwargs...))

DiffEqBase.initialize!(callbackset,u,t,integrator)
return integrator
Expand All @@ -46,10 +46,10 @@ end
# called by DiffEqBase.solve
function DiffEqBase.__solve(
prob::DiffEqBase.AbstractODEProblem,
alg::DistributedODEAlgorithm,
alg::DistributedODEAlgorithm,
args...;
kwargs...)

integrator = DiffEqBase.__init(prob, alg, args...; kwargs...)
DiffEqBase.solve!(integrator)
return integrator.u # ODEProblem returns a Solution objec
Expand All @@ -61,6 +61,10 @@ function DiffEqBase.solve!(integrator::DistributedODEIntegrator)
if integrator.adjustfinal && integrator.t + integrator.dt > integrator.tstop
adjust_dt!(integrator, integrator.tstop - integrator.t)
end
if !integrator.adjustfinal && integrator.t + integrator.dt/2 > integrator.tstop
break
end

DiffEqBase.step!(integrator)

if integrator.step == integrator.stepstop
Expand Down Expand Up @@ -90,13 +94,81 @@ function DiffEqBase.step!(integrator::DistributedODEIntegrator)
end

# solvers need to define this interface
step_u!(integrator) = step_u!(integrator, integrator.cache)
step_u!(integrator) = step_u!(integrator, integrator.cache)

"""
adjust_dt!(integrator::DistributedODEIntegrator, dt[, dt_cache=nothing])

function adjust_dt!(integrator::DistributedODEIntegrator, dt)
Adjust the time step of the integrator to `dt`. The optional `dt_cache` object
can be passed when the integrator has a `dt`-dependent component that needs to
be updated (such as a linear solver).
"""
function adjust_dt!(integrator::DistributedODEIntegrator, dt, dt_cache=nothing)
# TODO: figure out interface for recomputing other objects (linear operators, etc)
integrator.dt = dt
adjust_dt!(integrator.cache, dt, dt_cache)
end

# interfaces

"""
init_cache(prob, alg::A; kwargs...)::AC

Construct an algorithm cache for the algorithm `alg`. This should be defined
for any algorithm type `A`, and should return an object of an appropriate cache
type `AC` that can be dispatched on for [`step_u!`](@ref) and/or
[`init_inner`](@ref)/[`update_inner!`](@ref).
"""
function init_cache end

"""
step_u!(integrator, cache::AC)

Perform a single step that updates the state `integrator.u` using accordint to
the algorithm corresponding to `cache`.

This should be defined for any algorithm cache type `AC` that can be used
directly or as an inner timestepper. For outer timesteppers,
[`init_inner`](@ref) and [`update_inner!`](@ref) need to be defined instead.
"""
step_u!(integrator, cache)

"""
init_dt_cache(cache::AC, prob, dt)

Construct a `dt`-dependent subcache of `cache` for the ODE problem `prob`. This
should _not_ modify `cache` itself, but return an object that can be passed as
the `dt_cache` argument to [`adjust_dt!`](@ref).

By default this returns `nothing`. This should be defined for any algorithm
cache type `AC` which has `dt`-dependent components.

For example, an implicit solver can use this to return a factorized Euler
operator ``I-dt*L`` that is used as part of the implicit solve.

This initialization will typically be done as part of [`init_cache`](@ref)
itself: this interface is provided for multirate schemes which need to modify
the `dt` of the inner solver at each outer stage.
"""
function init_dt_cache(cache, prob, dt)
return nothing
end


function get_dt_cache(cache)
return nothing
end

"""
adjust_dt!(cache::AC, dt, dt_cache)

Adjust the time step of the algorithm cache `cache`. This should be defined for
any algorithm cache type `AC`, where `dt_cache` is an object returned by
[`init_dt_cache`](@ref).
"""
adjust_dt!(cache, dt, dt_cache)



# not sure what this should do?
# defined as default initialize: https://github.com/SciML/DiffEqBase.jl/blob/master/src/callbacks.jl#L3
Expand Down
39 changes: 29 additions & 10 deletions src/solvers/ark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ struct AdditiveRungeKuttaTableau{Nstages, Nstages², RT}
C::NTuple{Nstages, RT}
end

struct AdditiveRungeKuttaFullCache{Nstages, RT, A, O, L}
struct AdditiveRungeKuttaFullCache{Nstages,RT, A, G, O, L}
alg::G
"stage value of the state variable"
U::A #Qstages
"evaluated linear part of each stage ``f_L(U^{(i)})``"
Expand All @@ -38,8 +39,32 @@ struct AdditiveRungeKuttaFullCache{Nstages, RT, A, O, L}
linsolve!::L
end

function implicit_part(f::DiffEqBase.ODEFunction)
f.jvp === nothing && error("IMEX solvers require a `SplitODEFunction` or an `ODEFunction` with a `jvp` component.")
return f.jvp
end
implicit_part(f::DiffEqBase.SplitFunction) = f.f1
implicit_part(f::OffsetODEFunction) = implicit_part(f.f)

function cache(
function init_dt_cache(cache::AdditiveRungeKuttaFullCache, prob, dt)
_init_dt_cache(cache.alg, cache.tableau, prob, dt)
end
function _init_dt_cache(alg::AdditiveRungeKutta, tab, prob, dt)
f_impl = implicit_part(prob.f)
W = EulerOperator(f_impl , -dt*tab.Aimpl[2,2], prob.p, prob.tspan[1])
linsolve! = alg.linsolve(Val{:init}, W, prob.u0)
return (W, linsolve!)
end

function get_dt_cache(cache::AdditiveRungeKuttaFullCache)
return (cache.W, cache.linsolve!)
end
function adjust_dt!(cache::AdditiveRungeKuttaFullCache, dt, (W, linsolve!)::Tuple)
cache.W = W
cache.linsolve! = linsolve!
end

function init_cache(
prob::DiffEqBase.AbstractODEProblem{uType, tType, true},
alg::AdditiveRungeKutta; dt, kwargs...) where {uType,tType}

Expand All @@ -49,14 +74,8 @@ function cache(
L = ntuple(i -> zero(prob.u0), Nstages)
R = ntuple(i -> zero(prob.u0), Nstages)

if prob.f isa DiffEqBase.ODEFunction
W = EulerOperator(prob.f.jvp, -dt*tab.Aimpl[2,2], prob.p, prob.tspan[1])
elseif prob.f isa DiffEqBase.SplitFunction
W = EulerOperator(prob.f.f1, -dt*tab.Aimpl[2,2], prob.p, prob.tspan[1])
end
linsolve! = alg.linsolve(Val{:init}, W, prob.u0; kwargs...)

AdditiveRungeKuttaFullCache(U, L, R, tab, W, linsolve!)
W, linsolve! = _init_dt_cache(alg, tab, prob, dt)
AdditiveRungeKuttaFullCache(alg, U, L, R, tab, W, linsolve!)
end


Expand Down
15 changes: 13 additions & 2 deletions src/solvers/lsrk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ struct LowStorageRungeKutta2NIncCache{Nstages, RT, A}
du::A
end

function cache(prob::DiffEqBase.ODEProblem, alg::LowStorageRungeKutta2N; kwargs...)
function init_cache(prob::DiffEqBase.ODEProblem, alg::LowStorageRungeKutta2N; kwargs...)
# @assert prob.problem_type isa DiffEqBase.IncrementingODEProblem ||
# prob.f isa DiffEqBase.IncrementingODEFunction
du = zero(prob.u0)
Expand All @@ -59,8 +59,19 @@ function step_u!(int, cache::LowStorageRungeKutta2NIncCache)
end
end

adjust_dt!(cache::LowStorageRungeKutta2NIncCache, dt, ::Nothing) = nothing

# for Multirate
function init_inner(prob, outercache::LowStorageRungeKutta2NIncCache, dt)
function inner_dts(outercache::LowStorageRungeKutta2NIncCache, dt, fast_dt)
N = nstages(outercache)
tab = outercache.tableau
ntuple(N) do i
Δt = (i == N ? 1-tab.C[i] : tab.C[i+1] - tab.C[i]) * dt
Δt / round(Δt / fast_dt)
end
end

function init_inner_fun(prob, outercache::LowStorageRungeKutta2NIncCache, dt)
OffsetODEFunction(prob.f.f1, zero(dt), one(dt), zero(dt), outercache.du)
end
function update_inner!(innerinteg, outercache::LowStorageRungeKutta2NIncCache,
Expand Down
11 changes: 9 additions & 2 deletions src/solvers/mis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ end

nstages(::MultirateInfinitesimalStepCache{Nstages}) where {Nstages} = Nstages

function cache(
function init_cache(
prob::DiffEqBase.AbstractODEProblem{uType, tType, true},
alg::MultirateInfinitesimalStep; kwargs...) where {uType,tType}

Expand All @@ -66,8 +66,15 @@ function cache(
return MultirateInfinitesimalStepCache(ΔU, F, tab)
end

function inner_dts(outercache::MultirateInfinitesimalStepCache, dt, fast_dt)
tab = outercache.tableau
map(tab.d) do d_i
Δt = d_i*dt
Δt / round(Δt / fast_dt)
end
end

function init_inner(prob, outercache::MultirateInfinitesimalStepCache, dt)
function init_inner_fun(prob, outercache::MultirateInfinitesimalStepCache, dt)
OffsetODEFunction(prob.f.f1, zero(dt), one(dt), one(dt), outercache.ΔU[end])
end

Expand Down
110 changes: 89 additions & 21 deletions src/solvers/multirate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ struct Multirate{F,S} <: DistributedODEAlgorithm
end


struct MultirateCache{OC,II}
struct MultirateCache{OC,II,SD}
outercache::OC
innerinteg::II
dt_cache::SD
end

function cache(
function init_cache(
prob::DiffEqBase.AbstractODEProblem,
alg::Multirate;
dt, fast_dt, kwargs...)
Expand All @@ -33,13 +34,49 @@ function cache(

# subproblems
outerprob = DiffEqBase.remake(prob; f=prob.f.f2)
outercache = cache(outerprob, alg.slow)
outercache = init_cache(outerprob, alg.slow)

innerfun = init_inner(prob, outercache, dt)
sub_dts = inner_dts(outercache, dt, fast_dt)
unique_sub_dts = unique(sub_dts)

innerfun = init_inner_fun(prob, outercache, dt)
innerprob = DiffEqBase.remake(prob; f=innerfun)
innerinteg = DiffEqBase.init(innerprob, alg.fast; dt=fast_dt, kwargs...)
return MultirateCache(outercache, innerinteg)
innerinteg = DiffEqBase.init(innerprob, alg.fast; dt=unique_sub_dts[1], adjustfinal=false, kwargs...)

# build dt_cache
unique_dt_caches = [
i == 1 ? get_dt_cache(innerinteg.cache) : init_dt_cache(innerinteg.cache, innerinteg.prob, unique_sub_dts[i])
for i = 1:length(unique_sub_dts)]

dt_cache = map(sub_dts) do sub_dt
i = findfirst(==(sub_dt), unique_sub_dts)
unique_sub_dts[i] => unique_dt_caches[i]
end

return MultirateCache(outercache, innerinteg, dt_cache)
end

get_dt_cache(cache::Multirate) = cache.dt_cache
function init_dt_cache(cache::Multirate, prob, dt)
outercache = cache.outercache
innerinteg = cache.innerinteg

fast_dt = innerinteg.dt # TODO: get the original fast_dt from somewhere

sub_dts = inner_dts(outercache, dt, fast_dt)
unique_sub_dts = unique(sub_dts)

unique_dt_caches = [
init_dt_cache(innerinteg.cache, innerinteg.prob, unique_sub_dts[i])
for i = 1:length(unique_sub_dts)]

dt_cache = map(sub_dts) do sub_dt
i = findfirst(==(sub_dt), unique_sub_dts)
unique_sub_dts[i] => unique_dt_caches[i]
end
return dt_cache
end
adjust_dt!(cache::Multirate, dt, dt_cache::Tuple) = cache.dt_cache


function step_u!(int, cache::MultirateCache)
Expand All @@ -54,23 +91,54 @@ function step_u!(int, cache::MultirateCache)
innerinteg = cache.innerinteg
fast_dt = innerinteg.dt

N = nstages(outercache)
for stage in 1:N
for i in 1:nstages(outercache)
sub_dt, sub_dt_cache = cache.dt_cache[i]
adjust_dt!(innerinteg, sub_dt, sub_dt_cache)
update_inner!(innerinteg, outercache, int.prob.f.f2, u, p, t, dt, i)
DiffEqBase.solve!(innerinteg)
end
end

# interface
"""
nstages(outercache::AC)

update_inner!(innerinteg, outercache, int.prob.f.f2, u, p, t, dt, stage)
The number of stages of the algorithm determined by cache type `AC`. This should
be defined for any algorithm cache type `AC` used as an outer solver.
"""
function nstages end

# solve inner problem
# dv/dτ .= B[s]/(C[s+1] - C[s]) .* du .+ f_fast(v,τ) τ ∈ [τ0,τ1]

# TODO: make this more generic
# there are 2 strategies we can use here:
# a. use same fast_dt for all slow stages, use `adjustfinal=true`
# - problems for ARK (e.g. requires expensive LU factorization)
# b. use different fast_dt, cache expensive ops
"""
inner_dts(outercache::AC, dt, fast_dt)

innerinteg.adjustfinal = true
DiffEqBase.solve!(innerinteg)
innerinteg.dt = fast_dt # reset
end
end
The inner timesteps that will be used at each stage of the multirate procedure.

This should be defined for any algorithm cache type `AC` that will be used as an
outer solver, and should return a tuple of the length of the number of stages.
Each value will be approximately `fast_dt`, but rounded so that an integer
number of steps can be used at each outer stage (where `dt` is the slow time
step).
"""
function inner_dts end

"""
init_inner_fun(prob, outercache::AC, dt)

Construct the inner `ODEFunction` that will be used with inner solver. This
should be defined for any algorithm cache type `AC` that will be used as an
outer solver.
"""
function init_inner_fun end

"""
update_inner!(innerinteg, outercache::AC, f_slow, u, p, t, dt, i)

Update the inner integrator `innerinteg` for stage `i` of the outer algorithm.
This should be defined for any `outercache` type `AC`, and will typically modify:
- `innerinteg.prob.f`
- `innerinteg.u`
- `innerinteg.t`
- `innerinteg.tstop`
"""
function update_inner! end
Loading