From 7188d03a2a0b00b1f297e138b62660243dfec484 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 21 Sep 2023 12:49:14 +0800 Subject: [PATCH 1/2] WIP fix higher order. 3rd works --- src/stage1/forward.jl | 34 +++++++++++++++++++++++++--------- test/forward.jl | 10 ++-------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 29fcba33..ef3994d6 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -11,10 +11,14 @@ primal(z::ZeroTangent) = ZeroTangent() first_partial(x) = partial(x, 1) -shuffle_down(b::UniformBundle{N, B, U}) where {N, B, U} = - UniformBundle{N-1, <:Any}(UniformBundle{1, B}(b.primal, b.tangent.val), - UniformBundle{1, U}(b.tangent.val, b.tangent.val)) +function shuffle_down(b::UniformBundle{N, B, U}) where {N, B, U} + UniformBundle{1, <:Any}( + UniformBundle{N-1, B}(b.primal, b.tangent.val), + UniformBundle{N-1, U}(b.tangent.val, b.tangent.val) + ) +end +#== function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B} # N.B: This depends on the special properties of the canonical tangent index order Base.@constprop :aggressive function _sdown(i::Int64) @@ -24,14 +28,22 @@ function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B} ExplicitTangentBundle{1}(b.primal, (partial(b, 1),)), ntuple(_sdown, 1<<(N-1)-1)) end +==# function shuffle_down(b::TaylorBundle{N, B}) where {N, B} + #== Base.@constprop :aggressive function _sdown(i::Int64) ExplicitTangentBundle{1}(b.tangent.coeffs[i], (b.tangent.coeffs[i+1],)) end TaylorBundle{N-1}( ExplicitTangentBundle{1}(b.primal, (b.tangent.coeffs[1],)), ntuple(_sdown, N-1)) + ==# + + TaylorBundle{1}( + TaylorBundle{N-1}(b.primal, b.tangent.coeffs[1:end-1]), + (TaylorBundle{N-1}(b.tangent.coeffs[1], b.tangent.coeffs[2:end]),) + ) end @@ -55,8 +67,8 @@ function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N} end function taylor_compatible(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2} - partial(r, 1)[1] = primal(r)[2] || return false - return all(1:N-1) do ii + partial(r, 1)[1] == primal(r)[2] || return false + return all(1:N-1) do i partial(r, i+1)[1] == partial(r, i)[2] end end @@ -64,7 +76,7 @@ function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2} the_primal = primal(r)[1] if taylor_compatible(r) the_partials = ntuple(N+1) do i - if ii <= N + if i <= N partial(r, i)[1] # == `partial(r,i-1)[2]` (except first which is primal(r)[2]) else # ii = N+1 partial(r, i-1)[2] @@ -150,14 +162,16 @@ function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N} end (::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...) - +ChainRulesCore.frule((_, ẋ)) # Special case rules for performance +#== @Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N} s = primal(s) ExplicitTangentBundle{N}(getfield(primal(x), s), map(x->lifted_getfield(x, s), x.tangent.partials)) end + @Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::ATB{N}, inbounds::ATB{N}) where {N} s = primal(s) ExplicitTangentBundle{N}(getfield(primal(x), s, primal(inbounds)), @@ -169,6 +183,7 @@ end TaylorBundle{N}(getfield(primal(x), s), map(y->lifted_getfield(y, s), x.tangent.coeffs)) end +==# @Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}, inbounds::ATB{N}) where {N} s = primal(s) @@ -250,7 +265,7 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, r === nothing && return ZeroBundle{N}(nothing) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end - +#== function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N} r = Base.indexed_iterate(destructure(t), primal(i)) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) @@ -264,6 +279,7 @@ end function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) where {N} ∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1)) end +==# function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::ZeroBundle) where {N} field_ind = primal(i) @@ -282,4 +298,4 @@ function (this::∂☆{N})(f::ZeroBundle{N, Core.IntrinsicFunction}, args::ATB{N else error("Missing rule for intrinsic function $ff") end -end +end \ No newline at end of file diff --git a/test/forward.jl b/test/forward.jl index 48ccc76b..7b19fb50 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -1,18 +1,11 @@ module forward_tests using Diffractor -using Diffractor: var"'", ∂⃖, DiffractorRuleConfig, ZeroBundle using ChainRules using ChainRulesCore using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad using LinearAlgebra - using Test -const fwd = Diffractor.PrimeDerivativeFwd -const bwd = Diffractor.PrimeDerivativeBack - - - # Minimal 2-nd order forward smoke test @test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin), Diffractor.ExplicitTangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) @@ -26,7 +19,8 @@ let var"'" = Diffractor.PrimeDerivativeFwd @test recursive_sin'(1.0) == cos(1.0) @test recursive_sin''(1.0) == -sin(1.0) - @test_broken recursive_sin'''(1.0) == -cos(1.0) + @test recursive_sin'''(1.0) == -cos(1.0) + @test_broken recursive_sin''''(1.0) == sin(1.0) @test_broken recursive_sin'''''(1.0) == cos(1.0) @test_broken recursive_sin''''''(1.0) == -sin(1.0) From 75f3425d48681a7bc0867cae6fdbbc142439ffe9 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 21 Sep 2023 14:08:34 +0800 Subject: [PATCH 2/2] Update src/stage1/forward.jl --- src/stage1/forward.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index ef3994d6..7d3f41fc 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -162,7 +162,6 @@ function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N} end (::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...) -ChainRulesCore.frule((_, ẋ)) # Special case rules for performance #== @Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}