Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add @guess to @mtkmodel + add defaults and guess keywords to MTKModel.f #2709

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 61 additions & 16 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
:constants => Dict{Symbol, Dict}(),
:defaults => Dict{Symbol, Any}(),
:kwargs => Dict{Symbol, Dict}(),
:guesses => Dict{Symbol, Any}(),
:structural_parameters => Dict{Symbol, Dict}()
)
comps = Union{Symbol, Expr}[]
Expand All @@ -57,7 +58,8 @@
push!(exprs.args, :(parameters = []))
push!(exprs.args, :(systems = ODESystem[]))
push!(exprs.args, :(equations = Equation[]))
push!(exprs.args, :(defaults = Dict{Num, Union{Number, Symbol, Function}}()))
push!(exprs.args, :(_defaults = Dict{Num, Union{Number, Symbol, Function}}()))
push!(exprs.args, :(_guesses = Dict{Num, Union{Number, Symbol, Function}}()))

Base.remove_linenums!(expr)
for arg in expr.args
Expand Down Expand Up @@ -106,8 +108,12 @@
@inline pop_structure_dict!.(
Ref(dict), [:constants, :defaults, :kwargs, :structural_parameters])


push!(exprs.args, @inline sanitize_default_guess_kwargs(:defaults, :_defaults))
push!(exprs.args, @inline sanitize_default_guess_kwargs(:guesses, :_guesses))

sys = :($ODESystem($Equation[equations...], $iv, variables, parameters;
name, systems, gui_metadata = $gui_metadata, defaults))
name, systems, gui_metadata = $gui_metadata, defaults = _defaults, guesses = _guesses))

if ext[] === nothing
push!(exprs.args, :(var"#___sys___" = $sys))
Expand All @@ -128,18 +134,40 @@
$(d_evts...)
]))))

Base.remove_linenums!(exprs)

f = if length(where_types) == 0
:($(Symbol(:__, name, :__))(; name, $(kwargs...)) = $exprs)
:($(Symbol(:__, name, :__))(;
name,
defaults::Union{NamedTuple, Dict} = Dict{Union{Num, Symbol}, Union{Number, Symbol, Function}}(),
guesses::Union{NamedTuple, Dict} = Dict{Union{Num, Symbol}, Union{Number, Symbol, Function}}(),
$(kwargs...)) = $exprs)
else
f_with_where = Expr(:where)
push!(f_with_where.args,
:($(Symbol(:__, name, :__))(; name, $(kwargs...))), where_types...)
:($(Symbol(:__, name, :__))(;
name,
defaults = Dict{Union{Num, Symbol}, Union{Number, Symbol, Function}}(),
guesses = Dict{Union{Num, Symbol}, Union{Number, Symbol, Function}}(),
$(kwargs...))), where_types...)
:($f_with_where = $exprs)
end

:($name = $Model($f, $dict, $isconnector))
end

function sanitize_default_guess_kwargs(type, target)
return quote
$type isa NamedTuple && ($type = Dict(pairs($type)))
for (var"##k", var"##v") in $type
var"##var" = filter!( var"##p" -> nameof(var"##p") == var"##k", vcat(parameters, variables))
if length(var"##var") == 1 # Variable can be present once or can be absent
$target[var"##var"[1]] = var"##v"

Check warning on line 165 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L161-L165

Added lines #L161 - L165 were not covered by tests
end
end

Check warning on line 167 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L167

Added line #L167 was not covered by tests
end
end

pop_structure_dict!(dict, key) = length(dict[key]) == 0 && pop!(dict, key)

function update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
Expand Down Expand Up @@ -407,7 +435,9 @@
isassigned(icon) && error("This model has more than one icon.")
parse_icon!(body, dict, icon, mod)
elseif mname == Symbol("@defaults")
parse_system_defaults!(exprs, arg, dict)
parse_system_defaults_guesses!(exprs, arg, dict, :defaults)
elseif mname == Symbol("@guesses")
parse_system_defaults_guesses!(exprs, arg, dict, :guesses)

Check warning on line 440 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L439-L440

Added lines #L439 - L440 were not covered by tests
else
error("$mname is not handled.")
end
Expand Down Expand Up @@ -456,24 +486,25 @@
end
end

push_additional_defaults!(dict, a, b::Number) = dict[:defaults][a] = b
push_additional_defaults!(dict, a, b::QuoteNode) = dict[:defaults][a] = b.value
function push_additional_defaults!(dict, a, b::Expr)
dict[:defaults][a] = readable_code(b)
push_additional_defaults_guesses!(dict, type, a, b::Number) = dict[type][a] = b
push_additional_defaults_guesses!(dict, type, a, b::QuoteNode) = dict[type][a] = b.value
function push_additional_defaults_guesses!(dict, type, a, b::Expr)
dict[type][a] = readable_code(b)

Check warning on line 492 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L490-L492

Added lines #L490 - L492 were not covered by tests
end

function parse_system_defaults!(exprs, defaults_body, dict)
for default_arg in defaults_body.args[end].args
function parse_system_defaults_guesses!(exprs, dg_body, dict, type)
for dg_arg in dg_body.args[end].args
# for arg in default_arg.args
MLStyle.@match default_arg begin
MLStyle.@match dg_arg begin
# For cases like `p => 1` and `p => f()`. In both cases the definitions of
# `a`, here `p` and when `b` is a function, here `f` are available while
# defining the model
Expr(:call, :(=>), a, b) => begin
push!(exprs, :(defaults[$a] = $b))
push_additional_defaults!(dict, a, b)
_type = Symbol(:_, type)
push!(exprs, :($_type[$a] = $b))
push_additional_defaults_guesses!(dict, type, a, b)
end
_ => error("Invalid `defaults` entry $default_arg $(typeof(a)) $(typeof(b))")
_ => error("Invalid `$type` entry $dg_arg $(typeof(a)) $(typeof(b))")

Check warning on line 507 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L507

Added line #L507 was not covered by tests
end
end
end
Expand Down Expand Up @@ -557,7 +588,15 @@
function _parse_extend!(ext, a, b, dict, expr, kwargs, varexpr, vars)
extend_args!(a, b, dict, expr, kwargs, varexpr)
ext[] = a
push!(b.args, Expr(:kw, :name, Meta.quot(a)))
if length(b.args) > 1
if b.args[2].head == :parameters
push!(b.args[2].args, :defaults, :guesses, Expr(:kw, :name, Meta.quot(a)))
elseif b.args[2].head == :kw
b.args[2] = Expr(:parameters, :defaults, :guesses, Expr(:kw, :name, Meta.quot(a)), b.args[2:end]...)

Check warning on line 595 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L594-L595

Added lines #L594 - L595 were not covered by tests
end
else
push!(b.args, Expr(:parameters, :defaults, :guesses, Expr(:kw, :name, Meta.quot(a))))
end
push!(expr.args, :($a = $b))

dict[:extend] = [Symbol.(vars.args), a, b.args[1]]
Expand Down Expand Up @@ -587,6 +626,12 @@
Expr(:call, a′, _...) => begin
a = Symbol(Symbol("#mtkmodel"), :__anonymous__, a′)
b = body
@info 630 b.args[1] mod
try
getproperty(mod, b.args[1])
catch e
@info e

Check warning on line 633 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L633

Added line #L633 was not covered by tests
end
if (model = getproperty(mod, b.args[1])) isa Model
vars = Expr(:tuple)
append!(vars.args, names(model))
Expand Down
44 changes: 42 additions & 2 deletions test/model_parsing.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using ModelingToolkit, Test
using ModelingToolkit: get_connector_type, get_defaults, get_gui_metadata,
get_systems, get_ps, getdefault, getname, readable_code,
using ModelingToolkit: defaults, get_connector_type, get_defaults, get_gui_metadata,
get_systems, get_ps, getdefault, getname, guesses, readable_code,
scalarize, symtype, VariableDescription, RegularConnector
using URIs: URI
using Distributions
Expand Down Expand Up @@ -276,6 +276,46 @@ end
@test MockModel.structure[:defaults] == Dict(:n => 1.0, :n2 => "g()")
end

@testset "Defaults and Guesses" begin
@mtkmodel ToExtend begin
@parameters begin
t1
t2
end
end

@mtkmodel DefaultGuessModel begin
@extend ToExtend()#t1 = 0; t2 = 0)
@parameters begin
d
g
end
@defaults begin
d => 10
end
@guesses begin
g => 20
end
end

@named dg1 = DefaultGuessModel()
dg1 = complete(dg1)
@test defaults(dg1)[dg1.d] == 10
@test guesses(dg1)[dg1.g] == 20

@named dg2 = DefaultGuessModel(defaults = Dict(:d => 11, :t1 => 1), guesses = Dict(:g => 21))
dg2 = complete(dg2)
@test defaults(dg2)[dg2.d] == 11
@test defaults(dg2)[dg2.t1] == 1
@test guesses(dg2)[dg2.g] == 21

@named dg3 = DefaultGuessModel(defaults = (d = 12,), guesses = (g = 22, t1 = 2))
dg3 = complete(dg3)
@test defaults(dg3)[dg3.d] == 12
@test guesses(dg3)[dg3.g] == 22
@test guesses(dg3)[dg3.t1] == 2
end

@testset "Type annotation" begin
@mtkmodel TypeModel begin
@structural_parameters begin
Expand Down
Loading