Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP fix higher order. 3rd works #217

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions src/stage1/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@ primal(z::ZeroTangent) = ZeroTangent()

first_partial(x) = partial(x, 1)

shuffle_down(b::UniformBundle{N, B, U}) where {N, B, U} =
UniformBundle{N-1, <:Any}(UniformBundle{1, B}(b.primal, b.tangent.val),
UniformBundle{1, U}(b.tangent.val, b.tangent.val))
function shuffle_down(b::UniformBundle{N, B, U}) where {N, B, U}
UniformBundle{1, <:Any}(
UniformBundle{N-1, B}(b.primal, b.tangent.val),
UniformBundle{N-1, U}(b.tangent.val, b.tangent.val)
)
end

#==
function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B}
# N.B: This depends on the special properties of the canonical tangent index order
Base.@constprop :aggressive function _sdown(i::Int64)
Expand All @@ -24,14 +28,22 @@ function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B}
ExplicitTangentBundle{1}(b.primal, (partial(b, 1),)),
ntuple(_sdown, 1<<(N-1)-1))
end
==#

function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
#==
Base.@constprop :aggressive function _sdown(i::Int64)
ExplicitTangentBundle{1}(b.tangent.coeffs[i], (b.tangent.coeffs[i+1],))
end
TaylorBundle{N-1}(
ExplicitTangentBundle{1}(b.primal, (b.tangent.coeffs[1],)),
ntuple(_sdown, N-1))
==#

TaylorBundle{1}(
TaylorBundle{N-1}(b.primal, b.tangent.coeffs[1:end-1]),
(TaylorBundle{N-1}(b.tangent.coeffs[1], b.tangent.coeffs[2:end]),)
)
end


Expand All @@ -55,16 +67,16 @@ function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
end

function taylor_compatible(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
partial(r, 1)[1] = primal(r)[2] || return false
return all(1:N-1) do ii
partial(r, 1)[1] == primal(r)[2] || return false
return all(1:N-1) do i
partial(r, i+1)[1] == partial(r, i)[2]
end
end
function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
the_primal = primal(r)[1]
if taylor_compatible(r)
the_partials = ntuple(N+1) do i
if ii <= N
if i <= N
partial(r, i)[1] # == `partial(r,i-1)[2]` (except first which is primal(r)[2])
else # ii = N+1
partial(r, i-1)[2]
Expand Down Expand Up @@ -150,14 +162,15 @@ function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N}
end
(::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...)


# Special case rules for performance
#==
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
s = primal(s)
ExplicitTangentBundle{N}(getfield(primal(x), s),
map(x->lifted_getfield(x, s), x.tangent.partials))
end


@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
s = primal(s)
ExplicitTangentBundle{N}(getfield(primal(x), s, primal(inbounds)),
Expand All @@ -169,6 +182,7 @@ end
TaylorBundle{N}(getfield(primal(x), s),
map(y->lifted_getfield(y, s), x.tangent.coeffs))
end
==#

@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}, inbounds::ATB{N}) where {N}
s = primal(s)
Expand Down Expand Up @@ -250,7 +264,7 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N,
r === nothing && return ZeroBundle{N}(nothing)
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
end

#==
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N}
r = Base.indexed_iterate(destructure(t), primal(i))
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
Expand All @@ -264,6 +278,7 @@ end
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) where {N}
∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1))
end
==#

function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::ZeroBundle) where {N}
field_ind = primal(i)
Expand All @@ -282,4 +297,4 @@ function (this::∂☆{N})(f::ZeroBundle{N, Core.IntrinsicFunction}, args::ATB{N
else
error("Missing rule for intrinsic function $ff")
end
end
end
10 changes: 2 additions & 8 deletions test/forward.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
module forward_tests
using Diffractor
using Diffractor: var"'", ∂⃖, DiffractorRuleConfig, ZeroBundle
using ChainRules
using ChainRulesCore
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
using LinearAlgebra

using Test

const fwd = Diffractor.PrimeDerivativeFwd
const bwd = Diffractor.PrimeDerivativeBack



# Minimal 2-nd order forward smoke test
@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin),
Diffractor.ExplicitTangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0)
Expand All @@ -26,7 +19,8 @@ let var"'" = Diffractor.PrimeDerivativeFwd
@test recursive_sin'(1.0) == cos(1.0)
@test recursive_sin''(1.0) == -sin(1.0)

@test_broken recursive_sin'''(1.0) == -cos(1.0)
@test recursive_sin'''(1.0) == -cos(1.0)

@test_broken recursive_sin''''(1.0) == sin(1.0)
@test_broken recursive_sin'''''(1.0) == cos(1.0)
@test_broken recursive_sin''''''(1.0) == -sin(1.0)
Expand Down
Loading