From 784e7e950c2db17ce5f9cbd3cdbdb85bbfe8e5ed Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Sun, 30 Jul 2023 16:17:19 -0400 Subject: [PATCH 1/3] Do not aggressively apply dot --- src/Decapodes.jl | 2 +- src/decapodeacset.jl | 23 ++++++++++++++++++++++- test/diag2dwd.jl | 28 +++++++++++++++++++++++----- test/rewrite.jl | 2 +- 4 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/Decapodes.jl b/src/Decapodes.jl index eac1ea4f..711a82ae 100644 --- a/src/Decapodes.jl +++ b/src/Decapodes.jl @@ -15,7 +15,7 @@ using Base.Iterators import Unicode export normalize_unicode, DerivOp, append_dot, - SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, Decapode, NamedDecapode, SummationDecapode, fill_names!, expand_operators, add_constant!, add_parameter, infer_types!, resolve_overloads!, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, + SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, Decapode, NamedDecapode, SummationDecapode, fill_names!, dot_rename!, expand_operators, add_constant!, add_parameter, infer_types!, resolve_overloads!, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, Term, Var, Judgement, Eq, AppCirc1, AppCirc2, App1, App2, Plus, Tan, term, parse_decapode, VectorForm, PhysicsState, findname, findnode, compile, compile_env, gensim, evalsim, closest_point, flat_op, diff --git a/src/decapodeacset.jl b/src/decapodeacset.jl index de70f222..1293ea11 100644 --- a/src/decapodeacset.jl +++ b/src/decapodeacset.jl @@ -47,11 +47,32 @@ function fill_names!(d::AbstractNamedDecapode) for e in incident(d, :∂ₜ, :op1) s = d[e,:src] t = d[e, :tgt] + String(d[t,:name])[1] != '•' && continue d[t, :name] = append_dot(d[s,:name]) end - return d + d end +""" dot_rename!(d::AbstractNamedDecapode) + +Rename tangent variables by their depending variable appended with a dot. +e.g. If D == ∂ₜ(C), then rename D to Ċ. + +If a tangent variable updates multiple vars, choose one arbitrarily. +e.g. If D == ∂ₜ(C) and D == ∂ₜ(B), then rename D to either Ċor B ̇. +""" +function dot_rename!(d::AbstractNamedDecapode) + # TODO: This method only works because higher order derivatives have always + # appeared later in the Var table. + for e in incident(d, :∂ₜ, :op1) + s = d[e,:src] + t = d[e, :tgt] + d[t, :name] = append_dot(d[s,:name]) + end + d +end + + function make_sum_mult_unique!(d::AbstractNamedDecapode) snum = 1 mnum = 1 diff --git a/test/diag2dwd.jl b/test/diag2dwd.jl index dea29b21..d6d7671e 100644 --- a/test/diag2dwd.jl +++ b/test/diag2dwd.jl @@ -196,11 +196,14 @@ import Decapodes: DecaExpr E == ∂ₜ(D) end pt3 = SummationDecapode(parse_decapode(ParseTest3)) - @test pt3[:name] == [:Ċ, :Ċ̇, :C] + @test pt3[:name] == [:D, :E, :C] @test pt3[:incl] == [1,2] @test pt3[:src] == [3, 1] @test pt3[:tgt] == [1, 2] + dot_rename!(pt3) + @test pt3[:name] == [Symbol('C'*'\U0307'), Symbol('C'*'\U0307'*'\U0307'), :C] + # TODO: We should eventually recognize this equivalence #= ParseTest4 = quote D == D + C @@ -208,6 +211,21 @@ import Decapodes: DecaExpr end pt4 = SummationDecapode(parse_decapode(ParseTest4)) =# + # Do not rename TVars if they are given a name. + pt5 = SummationDecapode(parse_decapode(quote + X::Form0{Point} + V::Form0{Point} + + k::Constant{Point} + + ∂ₜ(X) == V + ∂ₜ(V) == -1*k*(X) + end)) + + @test pt5[:name] == [:X, :V, :k, :mult_1, Symbol('V'*'\U0307'), Symbol("-1")] + dot_rename!(pt5) + @test pt5[:name] == [:X, Symbol('X'*'\U0307'), :k, :mult_1, Symbol('X'*'\U0307'*'\U0307'), Symbol("-1")] + end Deca = quote (A, B, C)::Form0 @@ -389,7 +407,7 @@ end # We use set equality because we do not care about the order of the Var table. names_types_1 = Set(zip(t1[:name], t1[:type])) - names_types_expected_1 = Set([(:Ċ, :Form0)]) + names_types_expected_1 = Set([(:C, :Form0)]) @test issetequal(names_types_1, names_types_expected_1) # The type of the src of ∂ₜ is inferred. @@ -398,11 +416,11 @@ end ∂ₜ(C) == C end t2 = SummationDecapode(parse_decapode(Test2)) - t2[only(incident(t2, :Ċ, :name)), :type] = :Form0 + t2[only(incident(t2, :C, :name)), :type] = :Form0 infer_types!(t2) names_types_2 = Set(zip(t2[:name], t2[:type])) - names_types_expected_2 = Set([(:Ċ, :Form0)]) + names_types_expected_2 = Set([(:C, :Form0)]) @test issetequal(names_types_2, names_types_expected_2) # The type of the tgt of d is inferred. @@ -1144,4 +1162,4 @@ end op1 = [:⋆₂, :∂ₜ, :d₁] end -end \ No newline at end of file +end diff --git a/test/rewrite.jl b/test/rewrite.jl index 64ec7996..e27cefe8 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -335,7 +335,7 @@ Test8Res = average_rewrite(Test8) Test8Expected = @acset SummationDecapode{Any, Any, Symbol} begin Var = 3 type = Any[:Form1, :Form2, :Form0] - name = [:D₁, :D₂, :D₁̇ ] + name = [:D₁, :D₂, :F] TVar = 1 incl = [3] From 061f7f962ea06dca4c897b4094699b53275ec829 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Sun, 30 Jul 2023 18:36:57 -0400 Subject: [PATCH 2/3] Do not depend on op1 order --- src/decapodeacset.jl | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/src/decapodeacset.jl b/src/decapodeacset.jl index 1293ea11..fceb0453 100644 --- a/src/decapodeacset.jl +++ b/src/decapodeacset.jl @@ -53,6 +53,28 @@ function fill_names!(d::AbstractNamedDecapode) d end +""" find_dep_and_order(d::AbstractNamedDecapode) + +Find the order of each tangent variable in the Decapode, and the index of the variable that it is dependent on. Returns a tuple of (dep, order), both of which respecting the order in which incident(d, :∂ₜ, :op1) returns Vars. +""" +function find_dep_and_order(d::AbstractNamedDecapode) + dep = d[incident(d, :∂ₜ, :op1), :src] + order = ones(Int, nparts(d, :TVar)) + found = true + while found + found = false + for i in parts(d, :TVar) + deps = incident(d, :∂ₜ, :op1) ∩ incident(d, dep[i], :tgt) + if !isempty(deps) + dep[i] = d[first(deps), :src] + order[i] += 1 + found = true + end + end + end + (dep, order) +end + """ dot_rename!(d::AbstractNamedDecapode) Rename tangent variables by their depending variable appended with a dot. @@ -62,17 +84,18 @@ If a tangent variable updates multiple vars, choose one arbitrarily. e.g. If D == ∂ₜ(C) and D == ∂ₜ(B), then rename D to either Ċor B ̇. """ function dot_rename!(d::AbstractNamedDecapode) - # TODO: This method only works because higher order derivatives have always - # appeared later in the Var table. - for e in incident(d, :∂ₜ, :op1) - s = d[e,:src] + dep, order = find_dep_and_order(d) + for (i,e) in enumerate(incident(d, :∂ₜ, :op1)) t = d[e, :tgt] - d[t, :name] = append_dot(d[s,:name]) + name = d[dep[i],:name] + for _ in 1:order[i] + name = append_dot(name) + end + d[t, :name] = name end d end - function make_sum_mult_unique!(d::AbstractNamedDecapode) snum = 1 mnum = 1 From e0c3b9c3057367b29c8992f24b9707b7a089d3ac Mon Sep 17 00:00:00 2001 From: Luke Morris <70283489+lukem12345@users.noreply.github.com> Date: Sun, 30 Jul 2023 19:09:47 -0400 Subject: [PATCH 3/3] Fix mangled docstring whitespace --- src/decapodeacset.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/decapodeacset.jl b/src/decapodeacset.jl index fceb0453..af2bc0c1 100644 --- a/src/decapodeacset.jl +++ b/src/decapodeacset.jl @@ -81,7 +81,7 @@ Rename tangent variables by their depending variable appended with a dot. e.g. If D == ∂ₜ(C), then rename D to Ċ. If a tangent variable updates multiple vars, choose one arbitrarily. -e.g. If D == ∂ₜ(C) and D == ∂ₜ(B), then rename D to either Ċor B ̇. +e.g. If D == ∂ₜ(C) and D == ∂ₜ(B), then rename D to either Ċ or B ̇. """ function dot_rename!(d::AbstractNamedDecapode) dep, order = find_dep_and_order(d)