-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Explicit Enzyme rules #349
base: master
Are you sure you want to change the base?
Changes from 4 commits
3e3fdc6
7ede350
d171253
dece56c
22471da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ jobs: | |
matrix: | ||
version: | ||
- '1.6' | ||
- '1.10' | ||
- '1' | ||
os: | ||
- ubuntu-latest | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
module BijectorsEnzymeCoreExt | ||
|
||
if isdefined(Base, :get_extension) | ||
using EnzymeCore: | ||
Active, | ||
Const, | ||
Duplicated, | ||
DuplicatedNoNeed, | ||
BatchDuplicated, | ||
BatchDuplicatedNoNeed, | ||
EnzymeRules | ||
using Bijectors: find_alpha | ||
else | ||
using ..EnzymeCore: | ||
Active, | ||
Const, | ||
Duplicated, | ||
DuplicatedNoNeed, | ||
BatchDuplicated, | ||
BatchDuplicatedNoNeed, | ||
EnzymeRules | ||
using ..Bijectors: find_alpha | ||
end | ||
|
||
# Compute a tuple of partial derivatives wrt non-`Const` arguments | ||
# and `nothing`s for `Const` arguments | ||
function ∂find_alpha( | ||
Ω::Real, | ||
wt_y::Union{Const,Active,Duplicated,BatchDuplicated}, | ||
wt_u_hat::Union{Const,Active,Duplicated,BatchDuplicated}, | ||
b::Union{Const,Active,Duplicated,BatchDuplicated}, | ||
) | ||
# We reuse the following term in the computation of the derivatives | ||
Ωpb = Ω + b.val | ||
c = wt_u_hat.val * sech(Ωpb)^2 | ||
cp1 = c + 1 | ||
|
||
∂Ω_∂wt_y = wt_y isa Const ? nothing : oneunit(wt_y.val) / cp1 | ||
∂Ω_∂wt_u_hat = wt_u_hat isa Const ? nothing : -tanh(Ωpb) / cp1 | ||
∂Ω_∂b = b isa Const ? nothing : -c / cp1 | ||
|
||
return (∂Ω_∂wt_y, ∂Ω_∂wt_u_hat, ∂Ω_∂b) | ||
end | ||
|
||
# `muladd` for partial derivatives that can deal with `nothing` derivatives | ||
_muladd_partial(::Nothing, ::Const, x::Union{Real,Tuple{Vararg{Real}},Nothing}) = x | ||
_muladd_partial(x::Real, y::Duplicated, z::Real) = muladd(x, y.dval, z) | ||
_muladd_partial(x::Real, y::Duplicated, ::Nothing) = x * y.dval | ||
function _muladd_partial(x::Real, y::BatchDuplicated{<:Real,N}, z::NTuple{N,Real}) where {N} | ||
let x = x | ||
map((a, b) -> muladd(x, a, b), y.dval, z) | ||
end | ||
end | ||
_muladd_partial(x::Real, y::BatchDuplicated, ::Nothing) = map(Base.Fix1(*, x), y.dval) | ||
|
||
function EnzymeRules.forward( | ||
::Const{typeof(find_alpha)}, | ||
::Type{RT}, | ||
wt_y::Union{Const,Duplicated,BatchDuplicated}, | ||
wt_u_hat::Union{Const,Duplicated,BatchDuplicated}, | ||
b::Union{Const,Duplicated,BatchDuplicated}, | ||
) where {RT<:Union{Const,Duplicated,DuplicatedNoNeed,BatchDuplicated,BatchDuplicatedNoNeed}} | ||
# Check that the types of the activities are consistent | ||
if !( | ||
RT <: Union{Const,Duplicated,DuplicatedNoNeed} && | ||
wt_y isa Union{Const,Duplicated} && | ||
wt_u_hat isa Union{Const,Duplicated} && | ||
b isa Union{Const,Duplicated} | ||
) && !( | ||
RT <: Union{Const,BatchDuplicated,BatchDuplicatedNoNeed} && | ||
wt_y isa Union{Const,BatchDuplicated} && | ||
wt_u_hat isa Union{Const,BatchDuplicated} && | ||
b isa Union{Const,BatchDuplicated} | ||
) | ||
throw(ArgumentError("inconsistent activities")) | ||
end | ||
|
||
# Compute primal value | ||
Ω = find_alpha(wt_y.val, wt_u_hat.val, b.val) | ||
|
||
# Early exit if no derivatives are requested | ||
if RT <: Const | ||
return Ω | ||
end | ||
|
||
Ω̇ = if wt_y isa Const && wt_u_hat isa Const && b isa Const | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. with the bump to 0.13 you can probably do !needs_shadow || (the const conditions) |
||
# Trivial case: All partial derivatives are 0 | ||
zero(Ω) | ||
else | ||
# In all other cases we have to compute the partial derivatives | ||
∂Ω_∂wt_y, ∂Ω_∂wt_u_hat, ∂Ω_∂b = ∂find_alpha(Ω, wt_y, wt_u_hat, b) | ||
_muladd_partial( | ||
∂Ω_∂wt_y, | ||
wt_y, | ||
_muladd_partial(∂Ω_∂wt_u_hat, wt_u_hat, _muladd_partial(∂Ω_∂b, b, nothing)), | ||
) | ||
end | ||
|
||
if RT <: Duplicated | ||
@assert Ω̇ isa Real | ||
return Duplicated(Ω, Ω̇) | ||
elseif RT <: DuplicatedNoNeed | ||
@assert Ω̇ isa Real | ||
return Ω̇ | ||
elseif RT <: BatchDuplicated | ||
@assert Ω̇ isa Tuple{Vararg{Real}} | ||
return BatchDuplicated(Ω, Ω̇) | ||
else | ||
@assert RT <: BatchDuplicatedNoNeed | ||
@assert Ω̇ isa Tuple{Vararg{Real}} | ||
return Ω̇ | ||
end | ||
end | ||
|
||
struct Zero{T} | ||
x::T | ||
end | ||
(f::Zero)(_) = zero(f.x) | ||
|
||
function EnzymeRules.augmented_primal( | ||
config::EnzymeRules.Config, | ||
::Const{typeof(find_alpha)}, | ||
::Type{RT}, | ||
wt_y::Union{Const,Active}, | ||
wt_u_hat::Union{Const,Active}, | ||
b::Union{Const,Active}, | ||
) where {RT<:Union{Const,Active}} | ||
# Only compute the the original return value if it is actually needed | ||
Ω = | ||
if EnzymeRules.needs_primal(config) || | ||
EnzymeRules.needs_shadow(config) || | ||
!(RT <: Const || (wt_y isa Const && wt_u_hat isa Const && b isa Const)) | ||
find_alpha(wt_y.val, wt_u_hat.val, b.val) | ||
else | ||
nothing | ||
end | ||
|
||
tape = if RT <: Const || (wt_y isa Const && wt_u_hat isa Const && b isa Const) | ||
# Trivial case: No differentiation or all derivatives are 0 | ||
# Thus no tape is needed | ||
nothing | ||
else | ||
# Derivatives with respect to at least one argument needed | ||
# They are computed in the reverse pass, and therefore the original return is cached | ||
# In principle, the partial derivatives could be computed here and be cached | ||
# But Enzyme only executes the reverse pass once, | ||
# thus this would not increase efficiency but instead more values would have to be cached | ||
Ω | ||
end | ||
|
||
# Ensure that we follow the interface requirements of `augmented_primal` | ||
primal = EnzymeRules.needs_primal(config) ? Ω : nothing | ||
shadow = if EnzymeRules.needs_shadow(config) | ||
if EnzymeRules.width(config) === 1 | ||
zero(Ω) | ||
else | ||
ntuple(Zero(Ω), Val(EnzymeRules.width(config))) | ||
end | ||
else | ||
nothing | ||
end | ||
|
||
return EnzymeRules.AugmentedReturn(primal, shadow, tape) | ||
end | ||
|
||
struct ZeroOrNothing{N} end | ||
(::ZeroOrNothing)(::Const) = nothing | ||
(::ZeroOrNothing{1})(x::Active) = zero(x.val) | ||
(::ZeroOrNothing{N})(x::Active) where {N} = ntuple(Zero(x.val), Val{N}()) | ||
|
||
function EnzymeRules.reverse( | ||
config::EnzymeRules.Config, | ||
::Const{typeof(find_alpha)}, | ||
::Type{<:Const}, | ||
::Nothing, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this relies on some type stability guarantees, I'd probalby just mark it as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it is a fair assumption to make, but also might as well leave it |
||
wt_y::Union{Const,Active}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. at this point why bother having an annotation at all here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I want to ensure when and with which types of arguments the function is called, to be sure that it is actually doing the correct thing and does not give silently wrong results. Additionally, IMO it makes it easier to read and understand the code (non-annotated arguments and very generic types were quite confusing to me when I tried to read and learn from rules in Enzyme). |
||
wt_u_hat::Union{Const,Active}, | ||
b::Union{Const,Active}, | ||
) | ||
# Trivial case: Nothing to be differentiated (return activity is `Const`) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These should also likely be extended for batch size != 1. Basically in these cases active returns are just an ntuple of the relevant derivatives (in this case zero) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added support for batches in reverse mode (or at least I tried to 😄). How would I go about testing it? AFAICT There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The easiest way is probably to make a wrapper function that takes and returns variables by reference then calls this function. Then you can use batchduplicated There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried to add such a test in 22471da but now julia is aborted in CI due to an LLVM assertion error: https://github.com/TuringLang/Bijectors.jl/actions/runs/11788862392/job/32836666863?pr=349#step:5:488 Do you have any idea what is going and what the problem is? |
||
return map(ZeroOrNothing{EnzymeRules.width(config)}(), (wt_y, wt_u_hat, b)) | ||
end | ||
function EnzymeRules.reverse( | ||
::EnzymeRules.Config, | ||
::Const{typeof(find_alpha)}, | ||
::Active, | ||
::Nothing, | ||
::Const, | ||
::Const, | ||
::Const, | ||
) | ||
# Trivial case: Tape does not exist sice all partial derivatives are 0 | ||
return (nothing, nothing, nothing) | ||
end | ||
|
||
struct MulPartialOrNothing{T<:Union{Real,Tuple{Vararg{Real}}}} | ||
x::T | ||
end | ||
(::MulPartialOrNothing)(::Nothing) = nothing | ||
(f::MulPartialOrNothing{<:Real})(∂f_∂x::Real) = ∂f_∂x * f.x | ||
function (f::MulPartialOrNothing{<:NTuple{N,Real}})(∂f_∂x::Real) where {N} | ||
return map(Base.Fix1(*, ∂f_∂x), f.x) | ||
end | ||
|
||
function EnzymeRules.reverse( | ||
::EnzymeRules.Config, | ||
::Const{typeof(find_alpha)}, | ||
ΔΩ::Active, | ||
Ω::Real, | ||
wt_y::Union{Const,Active}, | ||
wt_u_hat::Union{Const,Active}, | ||
b::Union{Const,Active}, | ||
) | ||
# Tape must be `nothing` if all arguments are `Const` | ||
@assert !(wt_y isa Const && wt_u_hat isa Const && b isa Const) | ||
|
||
# Compute partial derivatives | ||
∂Ω_∂xs = ∂find_alpha(Ω, wt_y, wt_u_hat, b) | ||
return map(MulPartialOrNothing(ΔΩ.val), ∂Ω_∂xs) | ||
end | ||
|
||
end # module |
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Segfaults on older Julia versions, probably never supported | ||
# TODO: Enable tests on Julia >= 1.11 when updating to Enzyme 0.13 | ||
if v"1.10" <= VERSION < v"1.11" | ||
@testset "Enzyme: Bijectors.find_alpha" begin | ||
x = randn() | ||
y = expm1(randn()) | ||
z = randn() | ||
|
||
@testset "forward" begin | ||
# No batches | ||
@testset for RT in (Const, Duplicated, DuplicatedNoNeed), | ||
Tx in (Const, Duplicated), | ||
Ty in (Const, Duplicated), | ||
Tz in (Const, Duplicated) | ||
|
||
test_forward(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz)) | ||
end | ||
|
||
# Batches | ||
@testset for RT in (Const, BatchDuplicated, BatchDuplicatedNoNeed), | ||
Tx in (Const, BatchDuplicated), | ||
Ty in (Const, BatchDuplicated), | ||
Tz in (Const, BatchDuplicated) | ||
|
||
test_forward(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz)) | ||
end | ||
end | ||
@testset "reverse" begin | ||
@testset for RT in (Const, Active), | ||
Tx in (Const, Active), | ||
Ty in (Const, Active), | ||
Tz in (Const, Active) | ||
|
||
test_reverse(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz)) | ||
end | ||
end | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is using the rule system for Enzyme <=0.12, which was changed in Enzyme 0.13. [or equivalently EnzymeCore 0.7 and 0.8 iirc]. We should probably use the latter.
The breaking change of relevance is that forward mode rules now have a config as the first arg [similar to ReverseMode], with utility functions that specify whether the original return and/or derivative returns are requested.
For example: https://github.com/EnzymeAD/Enzyme.jl/blob/31df08b376634e2926e1d147c67284b20c493c22/src/internal_rules.jl#L858
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I know, this is intentional since the master branch uses Enzyme 0.12 and I wanted to keep the update to Enzyme 0.13 and the switch from the imported to the explicit rules separate.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I saw this, and one thing that confused me was: What's the purpose of the return type/activity (Const/Duplicated/DuplicatedNoNeed) if it's solely based on
needs_primal(config)
andneeds_shadow(config)
what's returned? Why does the argument exist at all if it is ignored?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For example, it was previously impossible to represent a function with a constant return, whose primal result was not required by enzyme
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simultaneously marking a return DuplicatedNoNeed is no longer how to specify that the primal is not needed (eg outside of rules you would do autodiff(ForwardWithPrimal, Duplicated, …)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I think I still don't fully understand. Based on the
sort!
example it seems the only use case would be to dispatch on batched/non-batched mode? But that could be done solely by dispatching on argumentxs
, it seems? Both versions (batched/non-batched) completely ignore whether the return type is Const/Duplicated/...There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So that’s one use of it.
But basically they represent different (though related) pieces of information.
const:
Result is not differentiated wrt.
It can either be required to be returned (needs primal) or not
Duplicated:
Result is differentiated wrt
The primal can be required to be returned or not (previously this was represented in dupnoneed vs duplicated, now duplicated is always passed in?
The shadow can be required to be returned or not (needs shadow). Prior to 0.13 it was always required to be returned, even if not needed.