Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicit Enzyme rules #349

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/AD.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
matrix:
version:
- '1.6'
- '1.10'
- '1'
os:
- ubuntu-latest
Expand Down
3 changes: 1 addition & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
Expand All @@ -36,7 +35,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BijectorsDistributionsADExt = "DistributionsAD"
BijectorsEnzymeExt = ["Enzyme", "EnzymeCore"]
BijectorsEnzymeCoreExt = "EnzymeCore"
BijectorsForwardDiffExt = "ForwardDiff"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsMooncakeExt = "Mooncake"
Expand Down
222 changes: 222 additions & 0 deletions ext/BijectorsEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
module BijectorsEnzymeCoreExt

if isdefined(Base, :get_extension)
using EnzymeCore:
Active,
Const,
Duplicated,
DuplicatedNoNeed,
BatchDuplicated,
BatchDuplicatedNoNeed,
EnzymeRules
using Bijectors: find_alpha
else
using ..EnzymeCore:
Active,
Const,
Duplicated,
DuplicatedNoNeed,
BatchDuplicated,
BatchDuplicatedNoNeed,
EnzymeRules
using ..Bijectors: find_alpha
end

# Compute a tuple of partial derivatives wrt non-`Const` arguments
# and `nothing`s for `Const` arguments
function ∂find_alpha(
Ω::Real,
wt_y::Union{Const,Active,Duplicated,BatchDuplicated},
wt_u_hat::Union{Const,Active,Duplicated,BatchDuplicated},
b::Union{Const,Active,Duplicated,BatchDuplicated},
)
# We reuse the following term in the computation of the derivatives
Ωpb = Ω + b.val
c = wt_u_hat.val * sech(Ωpb)^2
cp1 = c + 1

∂Ω_∂wt_y = wt_y isa Const ? nothing : oneunit(wt_y.val) / cp1
∂Ω_∂wt_u_hat = wt_u_hat isa Const ? nothing : -tanh(Ωpb) / cp1
∂Ω_∂b = b isa Const ? nothing : -c / cp1

return (∂Ω_∂wt_y, ∂Ω_∂wt_u_hat, ∂Ω_∂b)
end

# `muladd` for partial derivatives that can deal with `nothing` derivatives
_muladd_partial(::Nothing, ::Const, x::Union{Real,Tuple{Vararg{Real}},Nothing}) = x
_muladd_partial(x::Real, y::Duplicated, z::Real) = muladd(x, y.dval, z)
_muladd_partial(x::Real, y::Duplicated, ::Nothing) = x * y.dval
function _muladd_partial(x::Real, y::BatchDuplicated{<:Real,N}, z::NTuple{N,Real}) where {N}
let x = x
map((a, b) -> muladd(x, a, b), y.dval, z)
end
end
_muladd_partial(x::Real, y::BatchDuplicated, ::Nothing) = map(Base.Fix1(*, x), y.dval)

function EnzymeRules.forward(
::Const{typeof(find_alpha)},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is using the rule system for Enzyme <=0.12, which was changed in Enzyme 0.13. [or equivalently EnzymeCore 0.7 and 0.8 iirc]. We should probably use the latter.

The breaking change of relevance is that forward mode rules now have a config as the first arg [similar to ReverseMode], with utility functions that specify whether the original return and/or derivative returns are requested.

For example: https://github.com/EnzymeAD/Enzyme.jl/blob/31df08b376634e2926e1d147c67284b20c493c22/src/internal_rules.jl#L858

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is using the rule system for Enzyme <=0.12, which was changed in Enzyme 0.13.

Yes, I know, this is intentional since the master branch uses Enzyme 0.12 and I wanted to keep the update to Enzyme 0.13 and the switch from the imported to the explicit rules separate.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The breaking change of relevance is that forward mode rules now have a config as the first arg [similar to ReverseMode], with utility functions that specify whether the original return and/or derivative returns are requested.

I saw this, and one thing that confused me was: What's the purpose of the return type/activity (Const/Duplicated/DuplicatedNoNeed) if it's solely based on needs_primal(config) and needs_shadow(config) what's returned? Why does the argument exist at all if it is ignored?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, it was previously impossible to represent a function with a constant return, whose primal result was not required by enzyme

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simultaneously marking a return DuplicatedNoNeed is no longer how to specify that the primal is not needed (eg outside of rules you would do autodiff(ForwardWithPrimal, Duplicated, …)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I think I still don't fully understand. Based on the sort! example it seems the only use case would be to dispatch on batched/non-batched mode? But that could be done solely by dispatching on argument xs, it seems? Both versions (batched/non-batched) completely ignore whether the return type is Const/Duplicated/...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So that’s one use of it.

But basically they represent different (though related) pieces of information.

const:
Result is not differentiated wrt.
It can either be required to be returned (needs primal) or not
Duplicated:
Result is differentiated wrt
The primal can be required to be returned or not (previously this was represented in dupnoneed vs duplicated, now duplicated is always passed in?
The shadow can be required to be returned or not (needs shadow). Prior to 0.13 it was always required to be returned, even if not needed.

::Type{RT},
wt_y::Union{Const,Duplicated,BatchDuplicated},
wt_u_hat::Union{Const,Duplicated,BatchDuplicated},
b::Union{Const,Duplicated,BatchDuplicated},
) where {RT<:Union{Const,Duplicated,DuplicatedNoNeed,BatchDuplicated,BatchDuplicatedNoNeed}}
# Check that the types of the activities are consistent
if !(
RT <: Union{Const,Duplicated,DuplicatedNoNeed} &&
wt_y isa Union{Const,Duplicated} &&
wt_u_hat isa Union{Const,Duplicated} &&
b isa Union{Const,Duplicated}
) && !(
RT <: Union{Const,BatchDuplicated,BatchDuplicatedNoNeed} &&
wt_y isa Union{Const,BatchDuplicated} &&
wt_u_hat isa Union{Const,BatchDuplicated} &&
b isa Union{Const,BatchDuplicated}
)
throw(ArgumentError("inconsistent activities"))
end

# Compute primal value
Ω = find_alpha(wt_y.val, wt_u_hat.val, b.val)

# Early exit if no derivatives are requested
if RT <: Const
return Ω
end

Ω̇ = if wt_y isa Const && wt_u_hat isa Const && b isa Const
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with the bump to 0.13 you can probably do !needs_shadow || (the const conditions)

# Trivial case: All partial derivatives are 0
zero(Ω)
else
# In all other cases we have to compute the partial derivatives
∂Ω_∂wt_y, ∂Ω_∂wt_u_hat, ∂Ω_∂b = ∂find_alpha(Ω, wt_y, wt_u_hat, b)
_muladd_partial(
∂Ω_∂wt_y,
wt_y,
_muladd_partial(∂Ω_∂wt_u_hat, wt_u_hat, _muladd_partial(∂Ω_∂b, b, nothing)),
)
end

if RT <: Duplicated
@assert Ω̇ isa Real
return Duplicated(Ω, Ω̇)
elseif RT <: DuplicatedNoNeed
@assert Ω̇ isa Real
return Ω̇
elseif RT <: BatchDuplicated
@assert Ω̇ isa Tuple{Vararg{Real}}
return BatchDuplicated(Ω, Ω̇)
else
@assert RT <: BatchDuplicatedNoNeed
@assert Ω̇ isa Tuple{Vararg{Real}}
return Ω̇
end
end

struct Zero{T}
x::T
end
(f::Zero)(_) = zero(f.x)

function EnzymeRules.augmented_primal(
config::EnzymeRules.Config,
::Const{typeof(find_alpha)},
::Type{RT},
wt_y::Union{Const,Active},
wt_u_hat::Union{Const,Active},
b::Union{Const,Active},
) where {RT<:Union{Const,Active}}
# Only compute the the original return value if it is actually needed
Ω =
if EnzymeRules.needs_primal(config) ||
EnzymeRules.needs_shadow(config) ||
!(RT <: Const || (wt_y isa Const && wt_u_hat isa Const && b isa Const))
find_alpha(wt_y.val, wt_u_hat.val, b.val)
else
nothing
end

tape = if RT <: Const || (wt_y isa Const && wt_u_hat isa Const && b isa Const)
# Trivial case: No differentiation or all derivatives are 0
# Thus no tape is needed
nothing
else
# Derivatives with respect to at least one argument needed
# They are computed in the reverse pass, and therefore the original return is cached
# In principle, the partial derivatives could be computed here and be cached
# But Enzyme only executes the reverse pass once,
# thus this would not increase efficiency but instead more values would have to be cached
Ω
end

# Ensure that we follow the interface requirements of `augmented_primal`
primal = EnzymeRules.needs_primal(config) ? Ω : nothing
shadow = if EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) === 1
zero(Ω)
else
ntuple(Zero(Ω), Val(EnzymeRules.width(config)))
end
else
nothing
end

return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

struct ZeroOrNothing{N} end
(::ZeroOrNothing)(::Const) = nothing
(::ZeroOrNothing{1})(x::Active) = zero(x.val)
(::ZeroOrNothing{N})(x::Active) where {N} = ntuple(Zero(x.val), Val{N}())

function EnzymeRules.reverse(
config::EnzymeRules.Config,
::Const{typeof(find_alpha)},
::Type{<:Const},
::Nothing,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this relies on some type stability guarantees, I'd probalby just mark it as tape, without annotation

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is a fair assumption to make, but also might as well leave it

wt_y::Union{Const,Active},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

at this point why bother having an annotation at all here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to ensure when and with which types of arguments the function is called, to be sure that it is actually doing the correct thing and does not give silently wrong results. Additionally, IMO it makes it easier to read and understand the code (non-annotated arguments and very generic types were quite confusing to me when I tried to read and learn from rules in Enzyme).

wt_u_hat::Union{Const,Active},
b::Union{Const,Active},
)
# Trivial case: Nothing to be differentiated (return activity is `Const`)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should also likely be extended for batch size != 1.

Basically in these cases active returns are just an ntuple of the relevant derivatives (in this case zero)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added support for batches in reverse mode (or at least I tried to 😄). How would I go about testing it? AFAICT EnzymeTestUtils.test_reverse won't test batched mode since all activities are either Const or Active.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The easiest way is probably to make a wrapper function that takes and returns variables by reference then calls this function. Then you can use batchduplicated

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to add such a test in 22471da but now julia is aborted in CI due to an LLVM assertion error: https://github.com/TuringLang/Bijectors.jl/actions/runs/11788862392/job/32836666863?pr=349#step:5:488 Do you have any idea what is going and what the problem is?

return map(ZeroOrNothing{EnzymeRules.width(config)}(), (wt_y, wt_u_hat, b))
end
function EnzymeRules.reverse(
::EnzymeRules.Config,
::Const{typeof(find_alpha)},
::Active,
::Nothing,
::Const,
::Const,
::Const,
)
# Trivial case: Tape does not exist sice all partial derivatives are 0
return (nothing, nothing, nothing)
end

struct MulPartialOrNothing{T<:Union{Real,Tuple{Vararg{Real}}}}
x::T
end
(::MulPartialOrNothing)(::Nothing) = nothing
(f::MulPartialOrNothing{<:Real})(∂f_∂x::Real) = ∂f_∂x * f.x
function (f::MulPartialOrNothing{<:NTuple{N,Real}})(∂f_∂x::Real) where {N}
return map(Base.Fix1(*, ∂f_∂x), f.x)
end

function EnzymeRules.reverse(
::EnzymeRules.Config,
::Const{typeof(find_alpha)},
ΔΩ::Active,
Ω::Real,
wt_y::Union{Const,Active},
wt_u_hat::Union{Const,Active},
b::Union{Const,Active},
)
# Tape must be `nothing` if all arguments are `Const`
@assert !(wt_y isa Const && wt_u_hat isa Const && b isa Const)

# Compute partial derivatives
∂Ω_∂xs = ∂find_alpha(Ω, wt_y, wt_u_hat, b)
return map(MulPartialOrNothing(ΔΩ.val), ∂Ω_∂xs)
end

end # module
18 changes: 0 additions & 18 deletions ext/BijectorsEnzymeExt.jl

This file was deleted.

2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -33,6 +34,7 @@ Combinatorics = "1.0.2"
Compat = "3.46, 4.2"
DistributionsAD = "0.6.3"
Enzyme = "0.12.22"
EnzymeTestUtils = "0.1.8"
FillArrays = "1"
FiniteDifferences = "0.11, 0.12"
ForwardDiff = "0.10.12"
Expand Down
38 changes: 38 additions & 0 deletions test/ad/enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Segfaults on older Julia versions, probably never supported
# TODO: Enable tests on Julia >= 1.11 when updating to Enzyme 0.13
if v"1.10" <= VERSION < v"1.11"
@testset "Enzyme: Bijectors.find_alpha" begin
x = randn()
y = expm1(randn())
z = randn()

@testset "forward" begin
# No batches
@testset for RT in (Const, Duplicated, DuplicatedNoNeed),
Tx in (Const, Duplicated),
Ty in (Const, Duplicated),
Tz in (Const, Duplicated)

test_forward(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz))
end

# Batches
@testset for RT in (Const, BatchDuplicated, BatchDuplicatedNoNeed),
Tx in (Const, BatchDuplicated),
Ty in (Const, BatchDuplicated),
Tz in (Const, BatchDuplicated)

test_forward(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz))
end
end
@testset "reverse" begin
@testset for RT in (Const, Active),
Tx in (Const, Active),
Ty in (Const, Active),
Tz in (Const, Active)

test_reverse(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz))
end
end
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using ChainRulesTestUtils
using Combinatorics
using DistributionsAD
using Enzyme
using EnzymeTestUtils
using FiniteDifferences
using ForwardDiff
using Functors
Expand Down Expand Up @@ -68,6 +69,7 @@ end

if GROUP == "All" || GROUP == "AD"
include("ad/chainrules.jl")
include("ad/enzyme.jl")
include("ad/flows.jl")
include("ad/pd.jl")
include("ad/corr.jl")
Expand Down
Loading