Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not aggressively apply dot #136

Merged
merged 3 commits into from
Jul 30, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Decapodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using Base.Iterators
import Unicode

export normalize_unicode, DerivOp, append_dot,
SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, Decapode, NamedDecapode, SummationDecapode, fill_names!, expand_operators, add_constant!, add_parameter, infer_types!, resolve_overloads!, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D,
SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, Decapode, NamedDecapode, SummationDecapode, fill_names!, dot_rename!, expand_operators, add_constant!, add_parameter, infer_types!, resolve_overloads!, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D,
Term, Var, Judgement, Eq, AppCirc1, AppCirc2, App1, App2, Plus, Tan, term, parse_decapode,
VectorForm, PhysicsState, findname, findnode,
compile, compile_env, gensim, evalsim, closest_point, flat_op,
Expand Down
46 changes: 45 additions & 1 deletion src/decapodeacset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,53 @@ function fill_names!(d::AbstractNamedDecapode)
for e in incident(d, :∂ₜ, :op1)
s = d[e,:src]
t = d[e, :tgt]
String(d[t,:name])[1] != '•' && continue
d[t, :name] = append_dot(d[s,:name])
end
return d
d
end

""" find_dep_and_order(d::AbstractNamedDecapode)

Find the order of each tangent variable in the Decapode, and the index of the variable that it is dependent on. Returns a tuple of (dep, order), both of which respecting the order in which incident(d, :∂ₜ, :op1) returns Vars.
"""
function find_dep_and_order(d::AbstractNamedDecapode)
dep = d[incident(d, :∂ₜ, :op1), :src]
order = ones(Int, nparts(d, :TVar))
found = true
while found
found = false
for i in parts(d, :TVar)
deps = incident(d, :∂ₜ, :op1) ∩ incident(d, dep[i], :tgt)
if !isempty(deps)
dep[i] = d[first(deps), :src]
order[i] += 1
found = true
end
end
end
(dep, order)
end

""" dot_rename!(d::AbstractNamedDecapode)

Rename tangent variables by their depending variable appended with a dot.
e.g. If D == ∂ₜ(C), then rename D to Ċ.

If a tangent variable updates multiple vars, choose one arbitrarily.
e.g. If D == ∂ₜ(C) and D == ∂ₜ(B), then rename D to either Ċor B ̇.
lukem12345 marked this conversation as resolved.
Show resolved Hide resolved
"""
function dot_rename!(d::AbstractNamedDecapode)
dep, order = find_dep_and_order(d)
for (i,e) in enumerate(incident(d, :∂ₜ, :op1))
t = d[e, :tgt]
name = d[dep[i],:name]
for _ in 1:order[i]
name = append_dot(name)
end
d[t, :name] = name
end
d
end

function make_sum_mult_unique!(d::AbstractNamedDecapode)
Expand Down
28 changes: 23 additions & 5 deletions test/diag2dwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,18 +196,36 @@ import Decapodes: DecaExpr
E == ∂ₜ(D)
end
pt3 = SummationDecapode(parse_decapode(ParseTest3))
@test pt3[:name] == [:, :Ċ̇, :C]
@test pt3[:name] == [:D, :E, :C]
@test pt3[:incl] == [1,2]
@test pt3[:src] == [3, 1]
@test pt3[:tgt] == [1, 2]

dot_rename!(pt3)
@test pt3[:name] == [Symbol('C'*'\U0307'), Symbol('C'*'\U0307'*'\U0307'), :C]

# TODO: We should eventually recognize this equivalence
#= ParseTest4 = quote
D == D + C
D + C == C
end
pt4 = SummationDecapode(parse_decapode(ParseTest4)) =#

# Do not rename TVars if they are given a name.
pt5 = SummationDecapode(parse_decapode(quote
X::Form0{Point}
V::Form0{Point}

k::Constant{Point}

∂ₜ(X) == V
∂ₜ(V) == -1*k*(X)
end))

@test pt5[:name] == [:X, :V, :k, :mult_1, Symbol('V'*'\U0307'), Symbol("-1")]
dot_rename!(pt5)
@test pt5[:name] == [:X, Symbol('X'*'\U0307'), :k, :mult_1, Symbol('X'*'\U0307'*'\U0307'), Symbol("-1")]

end
Deca = quote
(A, B, C)::Form0
Expand Down Expand Up @@ -389,7 +407,7 @@ end

# We use set equality because we do not care about the order of the Var table.
names_types_1 = Set(zip(t1[:name], t1[:type]))
names_types_expected_1 = Set([(:, :Form0)])
names_types_expected_1 = Set([(:C, :Form0)])
@test issetequal(names_types_1, names_types_expected_1)

# The type of the src of ∂ₜ is inferred.
Expand All @@ -398,11 +416,11 @@ end
∂ₜ(C) == C
end
t2 = SummationDecapode(parse_decapode(Test2))
t2[only(incident(t2, :, :name)), :type] = :Form0
t2[only(incident(t2, :C, :name)), :type] = :Form0
infer_types!(t2)

names_types_2 = Set(zip(t2[:name], t2[:type]))
names_types_expected_2 = Set([(:, :Form0)])
names_types_expected_2 = Set([(:C, :Form0)])
@test issetequal(names_types_2, names_types_expected_2)

# The type of the tgt of d is inferred.
Expand Down Expand Up @@ -1144,4 +1162,4 @@ end

op1 = [:⋆₂, :∂ₜ, :d₁]
end
end
end
2 changes: 1 addition & 1 deletion test/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ Test8Res = average_rewrite(Test8)
Test8Expected = @acset SummationDecapode{Any, Any, Symbol} begin
Var = 3
type = Any[:Form1, :Form2, :Form0]
name = [:D₁, :D₂, :D₁̇ ]
name = [:D₁, :D₂, :F]

TVar = 1
incl = [3]
Expand Down