Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for typed function symbolics #1270

Merged
merged 6 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ StaticArraysCore = "1.4"
SymPy = "2.2"
SymbolicIndexingInterface = "0.3.14"
SymbolicLimits = "0.2.2"
SymbolicUtils = "2, 3"
SymbolicUtils = "3.7"
TermInterface = "2"
julia = "1.10"

Expand Down
118 changes: 103 additions & 15 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ function _parse_vars(macroname, type, x, transform=identity)
# x = 1, [connect = flow; unit = u"m^3/s"]
if Meta.isexpr(v, :(=))
v, val = v.args
# defaults with metadata for function variables
if Meta.isexpr(val, :block)
Base.remove_linenums!(val)
val = only(val.args)
end
if Meta.isexpr(val, :tuple) && length(val.args) == 2 && isoption(val.args[2])
options = val.args[2].args
val = val.args[1]
Expand All @@ -124,7 +129,7 @@ function _parse_vars(macroname, type, x, transform=identity)
isruntime, v = unwrap_runtime_var(v)
iscall = Meta.isexpr(v, :call)
isarray = Meta.isexpr(v, :ref)
if iscall && Meta.isexpr(v.args[1], :ref)
if iscall && Meta.isexpr(v.args[1], :ref) && !call_args_are_function(map(last∘unwrap_runtime_var, @view v.args[2:end]))
@warn("The variable syntax $v is deprecated. Use $(Expr(:ref, Expr(:call, v.args[1].args[1], v.args[2]), v.args[1].args[2:end]...)) instead.
The former creates an array of functions, while the latter creates an array valued function.
The deprecated syntax will cause an error in the next major release of Symbolics.
Expand Down Expand Up @@ -155,35 +160,61 @@ function _parse_vars(macroname, type, x, transform=identity)
return ex
end

call_args_are_function(_) = false
function call_args_are_function(call_args::AbstractArray)
!isempty(call_args) && (call_args[end] == :(..) || all(Base.Fix2(Meta.isexpr, :(::)), call_args))
end

function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, prop, transform, isruntime)
ndim = :($length(($(indices...),)))
vname = !isruntime ? Meta.quot(lhs) : lhs
if call_args[1] == :..
ex = :($CallWithMetadata($Sym{$FnType{Tuple, Array{$type, $ndim}}}($vname)))
if call_args_are_function(call_args)
vname, fntype = function_name_and_type(lhs)
# name was already unwrapped before calling this function and is of the form $x
if isruntime
_vname = vname
else
# either no ::fnType or $x::fnType
vname, fntype = function_name_and_type(lhs)
isruntime, vname = unwrap_runtime_var(vname)
if isruntime
_vname = vname
else
_vname = Meta.quot(vname)
end
end
argtypes = arg_types_from_call_args(call_args)
ex = :($CallWithMetadata($Sym{$FnType{$argtypes, Array{$type, $ndim}, $(fntype...)}}($_vname)))
else
ex = :($Sym{$FnType{Tuple, Array{$type, $ndim}}}($vname)(map($unwrap, ($(call_args...),))...))
vname = lhs
if isruntime
_vname = vname
else
_vname = Meta.quot(vname)
end
ex = :($Sym{$FnType{Tuple, Array{$type, $ndim}}}($_vname)(map($unwrap, ($(call_args...),))...))
end
ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),)))

if val !== nothing
ex = :($setdefaultval($ex, $val))
end
ex = setprops_expr(ex, prop, macroname, Meta.quot(lhs))
ex = setprops_expr(ex, prop, macroname, Meta.quot(vname))
#ex = :($scalarize_getindex($ex))

ex = :($wrap($ex))

ex = :($transform($ex))
if isruntime
lhs = gensym(lhs)
vname = gensym(Symbol(vname))
end
lhs, :($lhs = $ex)
vname, :($vname = $ex)
end

function construct_vars(macroname, v, type, call_args, val, prop, transform, isruntime)
issym = v isa Symbol
isarray = isa(v, Expr) && v.head == :ref
isarray = !isruntime && Meta.isexpr(v, :ref)
if isarray
# this can't be an array of functions, since that was handled by `construct_dep_array_vars`
var_name = v.args[1]
if Meta.isexpr(var_name, :(::))
var_name, type′ = var_name.args
Expand All @@ -192,6 +223,22 @@ function construct_vars(macroname, v, type, call_args, val, prop, transform, isr
isruntime, var_name = unwrap_runtime_var(var_name)
indices = v.args[2:end]
expr = _construct_array_vars(macroname, isruntime ? var_name : Meta.quot(var_name), type, call_args, val, prop, indices...)
elseif call_args_are_function(call_args)
var_name, fntype = function_name_and_type(v)
# name was already unwrapped before calling this function and is of the form $x
if isruntime
vname = var_name
else
# either no ::fnType or $x::fnType
var_name, fntype = function_name_and_type(v)
isruntime, var_name = unwrap_runtime_var(var_name)
if isruntime
vname = var_name
else
vname = Meta.quot(var_name)
end
end
expr = construct_var(macroname, fntype == () ? vname : Expr(:(::), vname, fntype[1]), type, call_args, val, prop)
else
var_name = v
if Meta.isexpr(v, :(::))
Expand All @@ -200,7 +247,7 @@ function construct_vars(macroname, v, type, call_args, val, prop, transform, isr
end
expr = construct_var(macroname, isruntime ? var_name : Meta.quot(var_name), type, call_args, val, prop)
end
lhs = isruntime ? gensym(var_name) : var_name
lhs = isruntime ? gensym(Symbol(var_name)) : var_name
rhs = :($transform($expr))
lhs, :($lhs = $rhs)
end
Expand Down Expand Up @@ -249,15 +296,54 @@ function Base.show(io::IO, c::CallWithMetadata)
print(io, "⋆")
end

struct CallWithParent end

function (f::CallWithMetadata)(args...)
metadata(unwrap(f.f(map(unwrap, args)...)), metadata(f))
setmetadata(metadata(unwrap(f.f(map(unwrap, args)...)), metadata(f)), CallWithParent, f)
end

Base.isequal(a::CallWithMetadata, b::CallWithMetadata) = isequal(a.f, b.f)

function arg_types_from_call_args(call_args)
if length(call_args) == 1 && only(call_args) == :..
return Tuple
end
Ts = map(call_args) do arg
if arg == :..
Vararg
elseif arg isa Expr && arg.head == :(::)
if length(arg.args) == 1
arg.args[1]
elseif arg.args[1] == :..
:(Vararg{$(arg.args[2])})
else
arg.args[2]
end
else
error("Invalid call argument $arg")
end
end
return :(Tuple{$(Ts...)})
end

function function_name_and_type(var_name)
if var_name isa Expr && Meta.isexpr(var_name, :(::), 2)
var_name.args[1], (var_name.args[2],)
else
var_name, ()
end
end

function construct_var(macroname, var_name, type, call_args, val, prop)
expr = if call_args === nothing
:($Sym{$type}($var_name))
elseif !isempty(call_args) && call_args[end] == :..
:($CallWithMetadata($Sym{$FnType{Tuple, $type}}($var_name)))
elseif call_args_are_function(call_args)
# function syntax is (x::TFunc)(.. or ::TArg1, ::TArg2)::TRet
# .. is Vararg
# (..)::ArgT is Vararg{ArgT}
var_name, fntype = function_name_and_type(var_name)
argtypes = arg_types_from_call_args(call_args)
:($CallWithMetadata($Sym{$FnType{$argtypes, $type, $(fntype...)}}($var_name)))
else
:($Sym{$FnType{NTuple{$(length(call_args)), Any}, $type}}($var_name)($(map(x->:($value($x)), call_args)...)))
end
Expand All @@ -283,9 +369,11 @@ function _construct_array_vars(macroname, var_name, type, call_args, val, prop,
expr = if call_args === nothing
ex = :($Sym{Array{$type, $ndim}}($var_name))
:($setmetadata($ex, $ArrayShapeCtx, ($(indices...),)))
elseif !isempty(call_args) && call_args[end] == :..
elseif call_args_are_function(call_args)
need_scalarize = true
ex = :($Sym{Array{$FnType{Tuple, $type}, $ndim}}($var_name))
var_name, fntype = function_name_and_type(var_name)
argtypes = arg_types_from_call_args(call_args)
ex = :($Sym{Array{$FnType{$argtypes, $type, $(fntype...)}, $ndim}}($var_name))
ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),)))
:($map($CallWithMetadata, $ex))
else
Expand Down
135 changes: 134 additions & 1 deletion test/macro.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Symbolics
import Symbolics: getsource, getdefaultval, wrap, unwrap, getname
import Symbolics: CallWithMetadata, getsource, getdefaultval, wrap, unwrap, getname
import SymbolicUtils: Term, symtype, FnType, BasicSymbolic, promote_symtype
using LinearAlgebra
using Test
Expand Down Expand Up @@ -238,3 +238,136 @@ spam(x) = 2x
sym = spam([a, 2a])
@test sym isa Num
@test unwrap(sym) isa BasicSymbolic{Real}

fn_defaults = [print, min, max, identity, (+), (-), max, sum, vcat, (*)]
fn_names = [Symbol(:f, i) for i in 1:10]

struct VariableFoo end
Symbolics.option_to_metadata_type(::Val{:foo}) = VariableFoo

function test_all_functions(fns)
f1, f2, f3, f4, f5, f6, f7, f8, f9, f10 = fns
@variables x y::Int z::Function w[1:3, 1:3] v[1:3, 1:3]::String
@test f1 isa CallWithMetadata{FnType{Tuple, Real}}
@test all(x -> symtype(x) <: Real, [f1(), f1(1), f1(x), f1(x, y), f1(x, y, x+y)])
@test f2 isa CallWithMetadata{FnType{Tuple{Any, Vararg}, Int}}
@test all(x -> symtype(x) <: Int, [f2(1), f2(z), f2(x), f2(x, y), f2(x, y, x+y)])
@test_throws ErrorException f2()
@test f3 isa CallWithMetadata{FnType{Tuple, Real, typeof(max)}}
@test all(x -> symtype(x) <: Real, [f3(), f3(1), f3(x), f3(x, y), f3(x, y, x+y)])
@test f4 isa CallWithMetadata{FnType{Tuple{Int}, Real}}
@test all(x -> symtype(x) <: Real, [f4(1), f4(y), f4(2y)])
@test_throws ErrorException f4(x)
@test f5 isa CallWithMetadata{FnType{Tuple{Int, Vararg{Int}}, Real}}
@test all(x -> symtype(x) <: Real, [f5(1), f5(y), f5(y, y), f5(2, 3)])
@test_throws ErrorException f5(x)
@test f6 isa CallWithMetadata{FnType{Tuple{Int, Int}, Int}}
@test all(x -> symtype(x) <: Int, [f6(1, 1), f6(y, y), f6(1, y), f6(y, 1)])
@test_throws ErrorException f6()
@test_throws ErrorException f6(1)
@test_throws ErrorException f6(x, y)
@test_throws ErrorException f6(y)
@test f7 isa CallWithMetadata{FnType{Tuple{Int, Int}, Int, typeof(max)}}
# call behavior tested by f6
@test f8 isa CallWithMetadata{FnType{Tuple{Function, Vararg}, Real, typeof(sum)}}
@test all(x -> symtype(x) <: Real, [f8(z), f8(z, x), f8(identity), f8(identity, x)])
@test_throws ErrorException f8(x)
@test_throws ErrorException f8(1)
@test f9 isa CallWithMetadata{FnType{Tuple, Vector{Real}}}
@test all(x -> symtype(unwrap(x)) <: Vector{Real} && size(x) == (3,), [f9(), f9(1), f9(x), f9(x + y), f9(z), f9(1, x)])
@test f10 isa CallWithMetadata{FnType{Tuple{Matrix{<:Real}, Matrix{<:Real}}, Matrix{Real}, typeof(*)}}
@test all(x -> symtype(unwrap(x)) <: Matrix{Real} && size(x) == (3, 3), [f10(w, w), f10(w, ones(3, 3)), f10(ones(3, 3), ones(3, 3)), f10(w + w, w)])
@test_throws ErrorException f10(w, v)
end

function test_functions_defaults(fns)
for (fn, def) in zip(fns, fn_defaults)
@test Symbolics.getdefaultval(fn, nothing) == def
end
end

function test_functions_metadata(fns)
for (i, fn) in enumerate(fns)
@test Symbolics.getmetadata(fn, VariableFoo, nothing) == i
end
end

fns = @test_nowarn @variables begin
f1(..)
f2(::Any, ..)::Int
(f3::typeof(max))(..)
f4(::Int)
f5(::Int, (..)::Int)
f6(::Int, ::Int)::Int
(f7::typeof(max))(::Int, ::Int)::Int
(f8::typeof(sum))(::Function, ..)
f9(..)[1:3]
(f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3]
# f11[1:3](::Int)::Int
end

test_all_functions(fns)

fns = @test_nowarn @variables begin
f1(..) = fn_defaults[1]
f2(::Any, ..)::Int = fn_defaults[2]
(f3::typeof(max))(..) = fn_defaults[3]
f4(::Int) = fn_defaults[4]
f5(::Int, (..)::Int) = fn_defaults[5]
f6(::Int, ::Int)::Int = fn_defaults[6]
(f7::typeof(max))(::Int, ::Int)::Int = fn_defaults[7]
(f8::typeof(sum))(::Function, ..) = fn_defaults[8]
f9(..)[1:3] = fn_defaults[9]
(f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3] = fn_defaults[10]
end

test_all_functions(fns)
test_functions_defaults(fns)

fns = @variables begin
f1(..) = fn_defaults[1], [foo = 1]
f2(::Any, ..)::Int = fn_defaults[2], [foo = 2;]
(f3::typeof(max))(..) = fn_defaults[3], [foo = 3;]
f4(::Int) = fn_defaults[4], [foo = 4;]
f5(::Int, (..)::Int) = fn_defaults[5], [foo = 5;]
f6(::Int, ::Int)::Int = fn_defaults[6], [foo = 6;]
(f7::typeof(max))(::Int, ::Int)::Int = fn_defaults[7], [foo = 7;]
(f8::typeof(sum))(::Function, ..) = fn_defaults[8], [foo = 8;]
f9(..)[1:3] = fn_defaults[9], [foo = 9;]
(f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3] = fn_defaults[10], [foo = 10;]
end

test_all_functions(fns)
test_functions_defaults(fns)
test_functions_metadata(fns)

fns = @test_nowarn @variables begin
f1(..), [foo = 1,]
f2(::Any, ..)::Int, [foo = 2,]
(f3::typeof(max))(..), [foo = 3,]
f4(::Int), [foo = 4,]
f5(::Int, (..)::Int), [foo = 5,]
f6(::Int, ::Int)::Int, [foo = 6,]
(f7::typeof(max))(::Int, ::Int)::Int, [foo = 7,]
(f8::typeof(sum))(::Function, ..), [foo = 8,]
f9(..)[1:3], [foo = 9,]
(f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3], [foo = 10,]
end

test_all_functions(fns)
test_functions_metadata(fns)

fns = @test_nowarn @variables begin
$(fn_names[1])(..)
$(fn_names[2])(::Any, ..)::Int
($(fn_names[3])::typeof(max))(..)
$(fn_names[4])(::Int)
$(fn_names[5])(::Int, (..)::Int)
$(fn_names[6])(::Int, ::Int)::Int
($(fn_names[7])::typeof(max))(::Int, ::Int)::Int
($(fn_names[8])::typeof(sum))(::Function, ..)
$(fn_names[9])(..)[1:3]
($(fn_names[10])::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3]
end

test_all_functions(fns)
Loading