Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow users to set array length via args in @mtkmodel #3055

Merged
merged 7 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions docs/src/basics/MTKLanguage.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,14 @@ end
@structural_parameters begin
f = sin
N = 2
M = 3
end
begin
v_var = 1.0
end
@variables begin
v(t) = v_var
v_array(t)[1:2, 1:3]
v_array(t)[1:N, 1:M]
v_for_defaults(t)
end
@extend ModelB(; p1)
Expand Down Expand Up @@ -310,10 +311,10 @@ end
- `:defaults`: Dictionary of variables and default values specified in the `@defaults`.
- `:extend`: The list of extended unknowns, name given to the base system, and name of the base system.
- `:structural_parameters`: Dictionary of structural parameters mapped to their metadata.
- `:parameters`: Dictionary of symbolic parameters mapped to their metadata. For
parameter arrays, length is added to the metadata as `:size`.
- `:variables`: Dictionary of symbolic variables mapped to their metadata. For
variable arrays, length is added to the metadata as `:size`.
- `:parameters`: Dictionary of symbolic parameters mapped to their metadata. Metadata of
the parameter arrays is, for now, omitted.
- `:variables`: Dictionary of symbolic variables mapped to their metadata. Metadata of
the variable arrays is, for now, omitted.
- `:kwargs`: Dictionary of keyword arguments mapped to their metadata.
- `:independent_variable`: Independent variable, which is added while generating the Model.
- `:equations`: List of equations (represented as strings).
Expand All @@ -324,10 +325,10 @@ For example, the structure of `ModelC` is:
julia> ModelC.structure
Dict{Symbol, Any} with 10 entries:
:components => Any[Union{Expr, Symbol}[:model_a, :ModelA], Union{Expr, Symbol}[:model_array_a, :ModelA, :(1:N)], Union{Expr, Symbol}[:model_array_b, :ModelA, :(1:N)]]
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_array=>Dict(:type=>Real, :size=>(2, 3)), :v_for_defaults=>Dict(:type=>Real))
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_for_defaults=>Dict(:type=>Real))
:icon => URI("https://github.com/SciML/SciMLDocs/blob/main/docs/src/assets/logo.png")
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2), :v=>Dict{Symbol, Any}(:value=>:v_var, :type=>Real), :v_array=>Dict{Symbol, Union{Nothing, UnionAll}}(:value=>nothing, :type=>AbstractArray{Real}), :v_for_defaults=>Dict{Symbol, Union{Nothing, DataType}}(:value=>nothing, :type=>Real), :p1=>Dict(:value=>nothing))
:structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2))
:kwargs => Dict{Symbol, Dict}(:f => Dict(:value => :sin), :N => Dict(:value => 2), :M => Dict(:value => 3), :v => Dict{Symbol, Any}(:value => :v_var, :type => Real), :v_for_defaults => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real), :p1 => Dict(:value => nothing)),
:structural_parameters => Dict{Symbol, Dict}(:f => Dict(:value => :sin), :N => Dict(:value => 2), :M => Dict(:value => 3))
:independent_variable => t
:constants => Dict{Symbol, Dict}(:c=>Dict{Symbol, Any}(:value=>1, :type=>Int64, :description=>"Example constant."))
:extend => Any[[:p2, :p1], Symbol("#mtkmodel__anonymous__ModelB"), :ModelB]
Expand Down
161 changes: 120 additions & 41 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,16 @@ function update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
end
end

function unit_handled_variable_value(mod, y, varname)
meta = parse_metadata(mod, y)
varval = if meta isa Nothing || get(meta, VariableUnit, nothing) isa Nothing
varname
else
:($convert_units($(meta[VariableUnit]), $varname))
end
return varval
end

function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
def = nothing, indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing,
type::Type = Real, meta = Dict{DataType, Expr}())
Expand Down Expand Up @@ -222,6 +232,66 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
varclass, where_types, meta)
return var, def, Dict()
end
Expr(:tuple, Expr(:(::), Expr(:ref, a, b...), type), y) || Expr(:tuple, Expr(:ref, a, b...), y) => begin
(@isdefined type) || (type = Real)
varname = Meta.isexpr(a, :call) ? a.args[1] : a
push!(kwargs, Expr(:kw, varname, nothing))
varval = unit_handled_variable_value(mod, y, varname)
if varclass == :parameters
var = :($varname = $first(@parameters $a[$(b...)]::$type = ($varval, $y)))
else
var = :($varname = $first(@variables $a[$(b...)]::$type = ($varval, $y)))
end
#TODO: update `dict` aka `Model.structure` with the metadata
(:($varname...), var), nothing, Dict()
end
Expr(:(=), Expr(:(::), Expr(:ref, a, b...), type), y) || Expr(:(=), Expr(:ref, a, b...), y) => begin
(@isdefined type) || (type = Real)
varname = Meta.isexpr(a, :call) ? a.args[1] : a
if Meta.isexpr(y, :tuple)
varval = unit_handled_variable_value(mod, y, varname)
val, y = (y.args[1], y.args[2:end])
push!(kwargs, Expr(:kw, varname, nothing))
if varclass == :parameters
var = :($varname = $varname === nothing ? $val : $varname;
$varname = $first(@parameters $a[$(b...)]::$type = (
$varval, $(y...))))
else
var = :($varname = $varname === nothing ? $val : $varname;
$varname = $first(@variables $a[$(b...)]::$type = (
$varval, $(y...))))
end
else
push!(kwargs, Expr(:kw, varname, nothing))
if varclass == :parameters
var = :($varname = $varname === nothing ? $y : $varname;
$varname = $first(@parameters $a[$(b...)]::$type = $varname))
else
var = :($varname = $varname === nothing ? $y : $varname;
$varname = $first(@variables $a[$(b...)]::$type = $varname))
end
end
#TODO: update `dict`` aka `Model.structure` with the metadata
(:($varname...), var), nothing, Dict()
end
Expr(:(::), Expr(:ref, a, b...), type) || Expr(:ref, a, b...) => begin
(@isdefined type) || (type = Real)
varname = a isa Expr && a.head == :call ? a.args[1] : a
push!(kwargs, Expr(:kw, varname, nothing))
if varclass == :parameters
var = :($varname = $first(@parameters $a[$(b...)]::$type = $varname))
elseif varclass == :variables
var = :($varname = $first(@variables $a[$(b...)]::$type = $varname))
else
throw("Symbolic array with arbitrary length is not handled for $varclass.
Please open an issue with an example.")
end
dict[varclass] = get!(dict, varclass) do
Dict{Symbol, Dict{Symbol, Any}}()
end
# dict[:kwargs][varname] = dict[varclass][varname] = Dict(:size => b)
(:($varname...), var), nothing, Dict()
end
Expr(:(=), a, b) => begin
Base.remove_linenums!(b)
def, meta = parse_default(mod, b)
Expand Down Expand Up @@ -268,11 +338,6 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
end
return var, def, Dict()
end
Expr(:ref, a, b...) => begin
indices = map(i -> UnitRange(i.args[2], i.args[end]), b)
parse_variable_def!(dict, mod, a, varclass, kwargs, where_types;
def, indices, type, meta)
end
_ => error("$arg cannot be parsed")
end
end
Expand Down Expand Up @@ -380,14 +445,23 @@ function parse_default(mod, a)
end
end

function parse_metadata(mod, a)
function parse_metadata(mod, a::Expr)
MLStyle.@match a begin
Expr(:vect, eles...) => Dict(parse_metadata(mod, e) for e in eles)
Expr(:vect, b...) => Dict(parse_metadata(mod, m) for m in b)
Expr(:tuple, a, b...) => parse_metadata(mod, b)
Expr(:(=), a, b) => Symbolics.option_to_metadata_type(Val(a)) => get_var(mod, b)
_ => error("Cannot parse metadata $a")
end
end

function parse_metadata(mod, metadata::AbstractArray)
ret = Dict()
for m in metadata
merge!(ret, parse_metadata(mod, m))
end
ret
end

function _set_var_metadata!(metadata_with_exprs, a, m, v::Expr)
push!(metadata_with_exprs, m => v)
a
Expand Down Expand Up @@ -645,6 +719,7 @@ function parse_variable_arg!(exprs, vs, dict, mod, arg, varclass, kwargs, where_
end

function convert_units(varunits::DynamicQuantities.Quantity, value)
value isa Nothing && return nothing
DynamicQuantities.ustrip(DynamicQuantities.uconvert(
DynamicQuantities.SymbolicUnits.as_quantity(varunits), value))
end
Expand All @@ -656,6 +731,7 @@ function convert_units(
end

function convert_units(varunits::Unitful.FreeUnits, value)
value isa Nothing && return nothing
Unitful.ustrip(varunits, value)
end

Expand All @@ -674,47 +750,50 @@ end
function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types)
vv, def, metadata_with_exprs = parse_variable_def!(
dict, mod, arg, varclass, kwargs, where_types)
name = getname(vv)

varexpr = if haskey(metadata_with_exprs, VariableUnit)
unit = metadata_with_exprs[VariableUnit]
quote
$name = if $name === nothing
$setdefault($vv, $def)
else
try
$setdefault($vv, $convert_units($unit, $name))
catch e
if isa(e, $(DynamicQuantities.DimensionError)) ||
isa(e, $(Unitful.DimensionError))
error("Unable to convert units for \'" * string(:($$vv)) * "\'")
elseif isa(e, MethodError)
error("No or invalid units provided for \'" * string(:($$vv)) *
"\'")
else
rethrow(e)
if !(vv isa Tuple)
name = getname(vv)
varexpr = if haskey(metadata_with_exprs, VariableUnit)
unit = metadata_with_exprs[VariableUnit]
quote
$name = if $name === nothing
$setdefault($vv, $def)
else
try
$setdefault($vv, $convert_units($unit, $name))
catch e
if isa(e, $(DynamicQuantities.DimensionError)) ||
isa(e, $(Unitful.DimensionError))
error("Unable to convert units for \'" * string(:($$vv)) * "\'")
elseif isa(e, MethodError)
error("No or invalid units provided for \'" * string(:($$vv)) *
"\'")
else
rethrow(e)
end
end
end
end
end
else
quote
$name = if $name === nothing
$setdefault($vv, $def)
else
$setdefault($vv, $name)
else
quote
$name = if $name === nothing
$setdefault($vv, $def)
else
$setdefault($vv, $name)
end
end
end
end

metadata_expr = Expr(:block)
for (k, v) in metadata_with_exprs
push!(metadata_expr.args,
:($name = $wrap($set_scalar_metadata($unwrap($name), $k, $v))))
end
metadata_expr = Expr(:block)
for (k, v) in metadata_with_exprs
push!(metadata_expr.args,
:($name = $wrap($set_scalar_metadata($unwrap($name), $k, $v))))
end

push!(varexpr.args, metadata_expr)
return vv isa Num ? name : :($name...), varexpr
push!(varexpr.args, metadata_expr)
return vv isa Num ? name : :($name...), varexpr
else
return vv
end
end

function handle_conditional_vars!(
Expand Down
49 changes: 46 additions & 3 deletions test/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ end
@test all(collect(hasmetadata.(model.l, ModelingToolkit.VariableDescription)))

@test all(lastindex.([model.a2, model.b2, model.d2, model.e2, model.h2]) .== 2)
@test size(model.l) == MockModel.structure[:parameters][:l][:size] == (2, 3)
@test size(model.l) == (2, 3)
@test_broken MockModel.structure[:parameters][:l][:size] == (2, 3)

model = complete(model)
@test getdefault(model.cval) == 1
Expand Down Expand Up @@ -313,7 +314,6 @@ end
@test_throws TypeError TypeModel(; name = :throws, par3 = true)
@test_throws TypeError TypeModel(; name = :throws, par4 = true)
# par7 should be an AbstractArray of BigFloat.
@test_throws MethodError TypeModel(; name = :throws, par7 = rand(Int, 3, 3))

# Test that array types are correctly added.
@named type_model2 = TypeModel(; par5 = rand(BigFloat, 3))
Expand Down Expand Up @@ -474,7 +474,8 @@ using ModelingToolkit: getdefault, scalarize

@named model_with_component_array = ModelWithComponentArray()

@test eval(ModelWithComponentArray.structure[:parameters][:r][:unit]) == eval(u"Ω")
@test_broken eval(ModelWithComponentArray.structure[:parameters][:r][:unit]) ==
eval(u"Ω")
@test lastindex(parameters(model_with_component_array)) == 3

# Test the constant `k`. Manually k's value should be kept in sync here
Expand Down Expand Up @@ -876,3 +877,45 @@ end
end),
false)
end

@testset "Array Length as an Input" begin
@mtkmodel VaryingLengthArray begin
@structural_parameters begin
N
M
end
@parameters begin
p1[1:N]
p2[1:N, 1:M]
end
@variables begin
v1(t)[1:N]
v2(t)[1:N, 1:M]
end
end

@named model = VaryingLengthArray(N = 2, M = 3)
@test length(model.p1) == 2
@test size(model.p2) == (2, 3)
@test length(model.v1) == 2
@test size(model.v2) == (2, 3)

@mtkmodel WithMetadata begin
@structural_parameters begin
N
end
@parameters begin
p_only_default[1:N] = 101
p_only_metadata[1:N], [description = "this only has metadata"]
p_both_default_and_metadata[1:N] = 102,
[description = "this has both default value and metadata"]
end
end

@named with_metadata = WithMetadata(N = 10)
@test getdefault(with_metadata.p_only_default) == 101
@test getdescription(with_metadata.p_only_metadata) == "this only has metadata"
@test getdefault(with_metadata.p_both_default_and_metadata) == 102
@test getdescription(with_metadata.p_both_default_and_metadata) ==
"this has both default value and metadata"
end
Loading