Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up generate_initializesystem() #3051

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 54 additions & 81 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,110 +5,83 @@ Generate `NonlinearSystem` which initializes an ODE problem from specified initi
"""
function generate_initializesystem(sys::ODESystem;
u0map = Dict(),
name = nameof(sys),
guesses = Dict(), check_defguess = false,
default_dd_value = 0.0,
algebraic_only = false,
initialization_eqs = [],
check_units = true,
kwargs...)
sts, eqs = unknowns(sys), equations(sys)
guesses = Dict(),
default_dd_guess = 0.0,
algebraic_only = false,
check_units = true, check_defguess = false,
name = nameof(sys), kwargs...)
vars = unique([unknowns(sys); getfield.((observed(sys)), :lhs)])
vars_set = Set(vars) # for efficient in-lookup

eqs = equations(sys)
idxs_diff = isdiffeq.(eqs)
idxs_alge = .!idxs_diff
num_alge = sum(idxs_alge)

# Start the equations list with algebraic equations
eqs_ics = eqs[idxs_alge]
u0 = Vector{Pair}(undef, 0)

# prepare map for dummy derivative substitution
eqs_diff = eqs[idxs_diff]
diffmap = Dict(getfield.(eqs_diff, :lhs) .=> getfield.(eqs_diff, :rhs))
observed_diffmap = Dict(Differential(get_iv(sys)).(getfield.((observed(sys)), :lhs)) .=>
Differential(get_iv(sys)).(getfield.((observed(sys)), :rhs)))
full_diffmap = merge(diffmap, observed_diffmap)
D = Differential(get_iv(sys))
diffmap = merge(
Dict(eq.lhs => eq.rhs for eq in eqs_diff),
Dict(D(eq.lhs) => D(eq.rhs) for eq in observed(sys))
)

full_states = unique([sts; getfield.((observed(sys)), :lhs)])
set_full_states = Set(full_states)
guesses = todict(guesses)
# 1) process dummy derivatives and u0map into initialization system
eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
guesses = merge(get_guesses(sys), todict(guesses))
schedule = getfield(sys, :schedule)

if schedule !== nothing
guessmap = [x[1] => get(guesses, x[1], default_dd_value)
for x in schedule.dummy_sub]
dd_guess = Dict(filter(x -> !isnothing(x[1]), guessmap))
if u0map === nothing || isempty(u0map)
filtered_u0 = u0map
else
filtered_u0 = Pair[]
for x in u0map
y = get(schedule.dummy_sub, x[1], x[1])
y = ModelingToolkit.fixpoint_sub(y, full_diffmap)

if y ∈ set_full_states
# defer initialization until defaults are merged below
push!(filtered_u0, y => x[2])
if !isnothing(schedule)
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
# set dummy derivatives to default_dd_guess unless specified
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
end
if !isnothing(u0map)
for (y, x) in u0map
y = get(schedule.dummy_sub, y, y)
y = fixpoint_sub(y, diffmap)
if y ∈ vars_set
# variables specified in u0 overrides defaults
push!(defs, y => x)
elseif y isa Symbolics.Arr
# scalarize array # TODO: don't scalarize arrays
_y = collect(y)
for i in eachindex(_y)
push!(filtered_u0, _y[i] => x[2][i])
end
# TODO: don't scalarize arrays
push!(defs, (collect(y) .=> x)...)
elseif y isa Symbolics.BasicSymbolic
# y is a derivative expression expanded
# add to the initialization equations
push!(eqs_ics, y ~ x[2])
# y is a derivative expression expanded; add it to the initialization equations
push!(eqs_ics, y ~ x)
else
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
end
end
filtered_u0 = todict(filtered_u0)
end
else
dd_guess = Dict()
filtered_u0 = todict(u0map)
end

defs = merge(defaults(sys), filtered_u0)
guesses = merge(get_guesses(sys), todict(guesses), dd_guess)

for st in full_states
if st ∈ keys(defs)
def = defs[st]

if def isa Equation
st ∉ keys(guesses) && check_defguess &&
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
push!(eqs_ics, def)
push!(u0, st => guesses[st])
else
push!(eqs_ics, st ~ def)
push!(u0, st => def)
end
elseif st ∈ keys(guesses)
push!(u0, st => guesses[st])
# 2) process other variables
for var in vars
if var ∈ keys(defs)
push!(eqs_ics, var ~ defs[var])
elseif var ∈ keys(guesses)
push!(defs, var => guesses[var])
elseif check_defguess
error("Invalid setup: unknown $(st) has no default value or initial guess")
error("Invalid setup: variable $(var) has no default value or initial guess")
end
end

# 3) process explicitly provided initialization equations
if !algebraic_only
for eq in [get_initialization_eqs(sys); initialization_eqs]
_eq = ModelingToolkit.fixpoint_sub(eq, full_diffmap)
push!(eqs_ics, _eq)
initialization_eqs = [get_initialization_eqs(sys); initialization_eqs]
for eq in initialization_eqs
eq = fixpoint_sub(eq, diffmap) # expand dummy derivatives
push!(eqs_ics, eq)
end
end

pars = [parameters(sys); get_iv(sys)]
nleqs = [eqs_ics; observed(sys)]

sys_nl = NonlinearSystem(nleqs,
full_states,
pars;
defaults = merge(ModelingToolkit.defaults(sys), todict(u0), dd_guess),
parameter_dependencies = parameter_dependencies(sys),
pars = [parameters(sys); get_iv(sys)] # include independent variable as pseudo-parameter
eqs_ics = [eqs_ics; observed(sys)]
return NonlinearSystem(
eqs_ics, vars, pars;
defaults = defs, parameter_dependencies = parameter_dependencies(sys),
checks = check_units,
name,
kwargs...)

return sys_nl
name, kwargs...
)
end
16 changes: 15 additions & 1 deletion test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ end
end

# https://github.com/SciML/ModelingToolkit.jl/issues/3029
@testset "Derivatives in Initialization Equations" begin
@testset "Derivatives in initialization equations" begin
@variables x(t)
sys = ODESystem(
[D(D(x)) ~ 0], t; initialization_eqs = [x ~ 0, D(x) ~ 1], name = :sys) |>
Expand All @@ -523,6 +523,20 @@ end
@test_nowarn ODEProblem(sys, [D(x) => 1.0], (0.0, 1.0), [])
end

# https://github.com/SciML/ModelingToolkit.jl/issues/3049
@testset "Derivatives in initialization guesses" begin
for sign in [-1.0, +1.0]
@variables x(t)
sys = ODESystem(
[D(D(x)) ~ 0], t;
initialization_eqs = [D(x)^2 ~ 1, x ~ 0], guesses = [D(x) => sign], name = :sys
) |> structural_simplify
prob = ODEProblem(sys, [], (0.0, 1.0), [])
sol = solve(prob, Tsit5())
@test sol(1.0, idxs = sys.x) ≈ sign # system with D(x(0)) = ±1 should solve to x(1) = ±1
end
end

# https://github.com/SciML/ModelingToolkit.jl/issues/2619
@parameters k1 k2 ω
@variables X(t) Y(t)
Expand Down
Loading