Skip to content

Commit

Permalink
Merge pull request #2995 from AayushSabharwal/as/callable-params
Browse files Browse the repository at this point in the history
feat: support callable parameters
  • Loading branch information
ChrisRackauckas authored Sep 21, 2024
2 parents 9aadc71 + 7579312 commit 90e6398
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 23 deletions.
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636"
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand Down Expand Up @@ -76,6 +77,7 @@ ChainRulesCore = "1"
Combinatorics = "1"
Compat = "3.42, 4"
ConstructionBase = "1"
DataInterpolations = "6.4"
DataStructures = "0.17, 0.18"
DeepDiffs = "1"
DiffEqBase = "6.103.0"
Expand All @@ -91,6 +93,7 @@ ExprTools = "0.1.10"
Expronicon = "0.8"
FindFirstFunctions = "1"
ForwardDiff = "0.10.3"
FunctionWrappers = "1.1"
FunctionWrappersWrappers = "0.1"
Graphs = "1.5.2"
InteractiveUtils = "1"
Expand Down Expand Up @@ -118,8 +121,8 @@ SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.3.29"
SymbolicUtils = "3.2"
Symbolics = "6.3"
SymbolicUtils = "3.7"
Symbolics = "6.12"
URIs = "1"
UnPack = "0.1, 1.0"
Unitful = "1.1"
Expand All @@ -129,6 +132,7 @@ julia = "1.9"
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -154,4 +158,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
4 changes: 3 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ using Base: RefValue
using Combinatorics
import Distributions
import FunctionWrappersWrappers
import FunctionWrappers: FunctionWrapper
using URIs: URI
using SciMLStructures
using Compat
Expand All @@ -63,7 +64,8 @@ using Symbolics: _parse_vars, value, @derivatives, get_variables,
VariableSource, getname, variable, Connection, connect,
NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval,
initial_state, transition, activeState, entry, hasnode,
ticksInState, timeInState, fixpoint_sub, fast_substitute
ticksInState, timeInState, fixpoint_sub, fast_substitute,
CallWithMetadata, CallWithParent
const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR)
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
jacobian_sparsity, isaffine, islinear, _iszero, _isone,
Expand Down
10 changes: 10 additions & 0 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ function isparameter(x)
end
end

function iscalledparameter(x)
x = unwrap(x)
return isparameter(getmetadata(x, CallWithParent, nothing))
end

function getcalledparameter(x)
x = unwrap(x)
return getmetadata(x, CallWithParent)
end

"""
toparam(s)
Expand Down
43 changes: 29 additions & 14 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ struct DiscreteIndex
end

const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}}
const NonnumericMap = Dict{
Union{BasicSymbolic, Symbolics.CallWithMetadata}, Tuple{Int, Int}}
const UnknownIndexMap = Dict{
BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
const TunableIndexMap = Dict{BasicSymbolic,
Expand All @@ -45,20 +47,20 @@ struct IndexCache
callback_to_clocks::Dict{Any, Vector{Int}}
tunable_idx::TunableIndexMap
constant_idx::ParamIndexMap
nonnumeric_idx::ParamIndexMap
nonnumeric_idx::NonnumericMap
observed_syms::Set{BasicSymbolic}
dependent_pars::Set{BasicSymbolic}
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
tunable_buffer_size::BufferTemplate
constant_buffer_sizes::Vector{BufferTemplate}
nonnumeric_buffer_sizes::Vector{BufferTemplate}
symbol_to_variable::Dict{Symbol, BasicSymbolic}
symbol_to_variable::Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}
end

function IndexCache(sys::AbstractSystem)
unks = solved_unknowns(sys)
unk_idxs = UnknownIndexMap()
symbol_to_variable = Dict{Symbol, BasicSymbolic}()
symbol_to_variable = Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}()

let idx = 1
for sym in unks
Expand Down Expand Up @@ -105,12 +107,11 @@ function IndexCache(sys::AbstractSystem)

tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
nonnumeric_buffers = Dict{Any, Set{BasicSymbolic}}()
nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}()

function insert_by_type!(buffers::Dict{Any, Set{BasicSymbolic}}, sym)
function insert_by_type!(buffers::Dict{Any, S}, sym, ctype) where {S}
sym = unwrap(sym)
ctype = symtype(sym)
buf = get!(buffers, ctype, Set{BasicSymbolic}())
buf = get!(buffers, ctype, S())
push!(buf, sym)
end

Expand Down Expand Up @@ -142,7 +143,7 @@ function IndexCache(sys::AbstractSystem)
clocks = get!(() -> Set{Int}(), disc_param_callbacks, sym)
push!(clocks, i)
else
insert_by_type!(constant_buffers, sym)
insert_by_type!(constant_buffers, sym, symtype(sym))
end
end
end
Expand Down Expand Up @@ -197,6 +198,9 @@ function IndexCache(sys::AbstractSystem)
for p in parameters(sys)
p = unwrap(p)
ctype = symtype(p)
if ctype <: FnType
ctype = fntype_to_function_type(ctype)
end
haskey(disc_idxs, p) && continue
haskey(constant_buffers, ctype) && p in constant_buffers[ctype] && continue
insert_by_type!(
Expand All @@ -212,12 +216,13 @@ function IndexCache(sys::AbstractSystem)
else
nonnumeric_buffers
end,
p
p,
ctype
)
end

function get_buffer_sizes_and_idxs(buffers::Dict{Any, Set{BasicSymbolic}})
idxs = ParamIndexMap()
function get_buffer_sizes_and_idxs(T, buffers::Dict)
idxs = T()
buffer_sizes = BufferTemplate[]
for (i, (T, buf)) in enumerate(buffers)
for (j, p) in enumerate(buf)
Expand All @@ -229,13 +234,18 @@ function IndexCache(sys::AbstractSystem)
idxs[rp] = (i, j)
idxs[rttp] = (i, j)
end
if T <: Symbolics.FnType
T = Any
end
push!(buffer_sizes, BufferTemplate(T, length(buf)))
end
return idxs, buffer_sizes
end

const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers)
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers)
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(
ParamIndexMap, constant_buffers)
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(
NonnumericMap, nonnumeric_buffers)

tunable_idxs = TunableIndexMap()
tunable_buffer_size = 0
Expand Down Expand Up @@ -401,7 +411,8 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
for temp in ic.discrete_buffer_sizes)
const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
for temp in ic.constant_buffer_sizes)
nonnumeric_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
nonnumeric_buf = Tuple(Union{BasicSymbolic, CallWithMetadata}[unwrap(variable(:DEF))
for _ in 1:(temp.length)]
for temp in ic.nonnumeric_buffer_sizes)
for p in ps
p = unwrap(p)
Expand Down Expand Up @@ -481,3 +492,7 @@ function get_buffer_template(ic::IndexCache, pidx::ParameterIndex)
error("Unhandled portion $portion")
end
end

fntype_to_function_type(::Type{FnType{A, R, T}}) where {A, R, T} = T
fntype_to_function_type(::Type{FnType{A, R, Nothing}}) where {A, R} = FunctionWrapper{R, A}
fntype_to_function_type(::Type{FnType{A, R}}) where {A, R} = FunctionWrapper{R, A}
3 changes: 3 additions & 0 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ function MTKParameters(
if symbolic_type(val) !== NotSymbolic()
error("Could not evaluate value of parameter $sym. Missing values for variables in expression $val.")
end
if ctype <: FnType
ctype = fntype_to_function_type(ctype)
end
val = symconvert(ctype, val)
done = set_value(sym, val)
if !done && Symbolics.isarraysymbolic(sym)
Expand Down
3 changes: 2 additions & 1 deletion src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,8 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
@set! sys.defaults = merge(ModelingToolkit.defaults(sys),
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
end
ps = [setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous))
ps = [sym isa CallWithMetadata ? sym :
setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous))
for sym in get_ps(sys)]
@set! sys.ps = ps
else
Expand Down
12 changes: 10 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,11 @@ end
vars(exprs::Num; op = Differential) = vars(unwrap(exprs); op)
vars(exprs::Symbolics.Arr; op = Differential) = vars(unwrap(exprs); op)
function vars(exprs; op = Differential)
foldl((x, y) -> vars!(x, unwrap(y); op = op), exprs; init = Set())
if hasmethod(iterate, Tuple{typeof(exprs)})
foldl((x, y) -> vars!(x, unwrap(y); op = op), exprs; init = Set())
else
vars!(Set(), unwrap(exprs); op)
end
end
vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op)
function vars!(vars, eq::Equation; op = Differential)
Expand Down Expand Up @@ -479,7 +483,11 @@ end

function collect_var!(unknowns, parameters, var, iv)
isequal(var, iv) && return nothing
if isparameter(var) || (iscall(var) && isparameter(operation(var)))
if iscalledparameter(var)
callable = getcalledparameter(var)
push!(parameters, callable)
collect_vars!(unknowns, parameters, arguments(var), iv)
elseif isparameter(var) || (iscall(var) && isparameter(operation(var)))
push!(parameters, var)
elseif !isconstant(var)
push!(unknowns, var)
Expand Down
4 changes: 2 additions & 2 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,15 @@ eqs = [D(x) ~ σ(t - 1) * (y - x),
D(y) ~ x *- z) - y,
D(z) ~ x * y - β * z * κ]
@named de = ODESystem(eqs, t)
test_diffeq_inference("single internal iv-varying", de, t, (x, y, z), (σ(t - 1), ρ, β))
test_diffeq_inference("single internal iv-varying", de, t, (x, y, z), (σ, ρ, β))
f = eval(generate_function(de, [x, y, z], [σ, ρ, β])[2])
du = [0.0, 0.0, 0.0]
f(du, [1.0, 2.0, 3.0], [x -> x + 7, 2, 3], 5.0)
@test du [11, -3, -7]

eqs = [D(x) ~ x + 10σ(t - 1) + 100σ(t - 2) + 1000σ(t^2)]
@named de = ODESystem(eqs, t)
test_diffeq_inference("many internal iv-varying", de, t, (x,), (σ(t - 2), σ(t^2), σ(t - 1)))
test_diffeq_inference("many internal iv-varying", de, t, (x,), (σ,))
f = eval(generate_function(de, [x], [σ])[2])
du = [0.0]
f(du, [1.0], [t -> t + 2], 5.0)
Expand Down
45 changes: 45 additions & 0 deletions test/split_parameters.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
using ModelingToolkit, Test
using ModelingToolkitStandardLibrary.Blocks
using OrdinaryDiffEq
using DataInterpolations
using BlockArrays: BlockedArray
using ModelingToolkit: t_nounits as t, D_nounits as D
using ModelingToolkit: MTKParameters, ParameterIndex, NONNUMERIC_PORTION
using SciMLStructures: Tunable, Discrete, Constants
using StaticArrays: SizedVector
using SymbolicIndexingInterface: is_parameter, getp

x = [1, 2.0, false, [1, 2, 3], Parameter(1.0)]

Expand Down Expand Up @@ -219,3 +221,46 @@ S = get_sensitivity(closed_loop, :u)
@test ps[ParameterIndex(Tunable(), 1:8)] == collect(1.0:8.0) .+ 0.5
@test ps[ParameterIndex(Discrete(), (2, 1, 2, 2))] == 5
end

@testset "Callable parameters" begin
@testset "As FunctionWrapper" begin
_f1(x) = 2x
struct Foo end
(::Foo)(x) = 3x
@variables x(t)
@parameters fn(::Real) = _f1
@mtkbuild sys = ODESystem(D(x) ~ fn(t), t)
@test is_parameter(sys, fn)
@test ModelingToolkit.defaults(sys)[fn] == _f1

getter = getp(sys, fn)
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0))
@inferred getter(prob)
# cannot be inferred better since `FunctionWrapper` is only known to return `Real`
@inferred Vector{<:Real} prob.f(prob.u0, prob.p, prob.tspan[1])
sol = solve(prob, Tsit5(); abstol = 1e-10, reltol = 1e-10)
@test sol.u[end][] 2.0

prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => Foo()])
@inferred getter(prob)
@inferred Vector{<:Real} prob.f(prob.u0, prob.p, prob.tspan[1])
sol = solve(prob; abstol = 1e-10, reltol = 1e-10)
@test sol.u[end][] 2.5
end

@testset "Concrete function type" begin
ts = 0.0:0.1:1.0
interp = LinearInterpolation(ts .^ 2, ts; extrapolate = true)
@variables x(t)
@parameters (fn::typeof(interp))(..)
@mtkbuild sys = ODESystem(D(x) ~ fn(x), t)
@test is_parameter(sys, fn)
getter = getp(sys, fn)
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => interp])
@inferred getter(prob)
@inferred prob.f(prob.u0, prob.p, prob.tspan[1])
@test_nowarn sol = solve(prob, Tsit5())
@test_nowarn prob.ps[fn] = LinearInterpolation(ts .^ 3, ts; extrapolate = true)
@test_nowarn sol = solve(prob)
end
end

0 comments on commit 90e6398

Please sign in to comment.