Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
feat: update minimum version of Enzyme to 0.13
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 21, 2024
1 parent 79ed8fe commit 6dd7701
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ ChainRulesCore = "1.24"
Compat = "4.15.0"
CpuId = "0.3"
DispatchDoctor = "0.4.12"
EnzymeCore = "0.7.7"
EnzymeCore = "0.8"
FastClosures = "0.3.2"
ForwardDiff = "0.10.36"
Hwloc = "3.2"
Expand Down
5 changes: 3 additions & 2 deletions src/impl/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,18 +213,19 @@ end
# Enzyme works for all of these except `gelu`.
# See https://github.com/EnzymeAD/Enzyme.jl/issues/1671
function EnzymeRules.augmented_primal(
cfg::EnzymeRules.ConfigWidth{1}, func::EnzymeCore.Const{typeof(gelu)},
cfg::EnzymeRules.RevConfigWidth{1}, func::EnzymeCore.Const{typeof(gelu)},
::Type{<:EnzymeCore.Active}, x::EnzymeCore.Active{<:Number})
primal = EnzymeRules.needs_primal(cfg) ? func.val(x.val) : nothing
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function EnzymeRules.reverse(
::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu)},
::EnzymeRules.RevConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu)},
dret::EnzymeCore.Active, ::Nothing, x::EnzymeCore.Active{<:Number})
return (dret.val * ∇gelu(x.val),)
end

# FIXME: ForwardRules changed in EnzymeCore 0.8
function EnzymeRules.forward(
::EnzymeCore.Const{typeof(gelu)}, ::Type{<:EnzymeCore.Duplicated},
x::EnzymeCore.Duplicated{<:Number})
Expand Down
4 changes: 2 additions & 2 deletions src/impl/batched_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ end
for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!)
@eval begin
function EnzymeRules.augmented_primal(
cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))},
cfg::EnzymeRules.RevConfigWidth, ::EnzymeCore.Const{typeof($(func))},
::Type{RT}, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}},
A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}},
B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT}
Expand All @@ -155,7 +155,7 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!)
end

function EnzymeRules.reverse(
cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))},
cfg::EnzymeRules.RevConfigWidth, ::EnzymeCore.Const{typeof($(func))},
::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}},
A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}},
B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT}
Expand Down
5 changes: 3 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ CRC.@non_differentiable safe_minimum(::Any...)
macro enzyme_alternative(f₁, f₂)
return esc(quote
function EnzymeRules.augmented_primal(
::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(f₁))},
::EnzymeRules.RevConfig, ::EnzymeCore.Const{typeof($(f₁))},
::Type{RT}, args...) where {RT}
fwd, rev = EnzymeCore.autodiff_thunk(
EnzymeCore.ReverseSplitWithPrimal, EnzymeCore.Const{typeof($(f₂))},
Expand All @@ -245,11 +245,12 @@ macro enzyme_alternative(f₁, f₂)
end

function EnzymeRules.reverse(
::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(f₁))},
::EnzymeRules.RevConfig, ::EnzymeCore.Const{typeof($(f₁))},
::Type{RT}, (tape, rev), args...) where {RT}
return only(rev(EnzymeCore.Const($(f₂)), args..., tape))
end

# FIXME: ForwardRules changed in EnzymeCore 0.8
function EnzymeRules.forward(
::EnzymeCore.Const{typeof($(f₁))}, ::Type{RT}, args...) where {RT}
EnzymeCore.autodiff(EnzymeCore.Forward, EnzymeCore.Const($(f₂)), RT, args...)
Expand Down
6 changes: 3 additions & 3 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ BLISBLAS = "0.1"
BenchmarkTools = "1.5"
ChainRulesCore = "1.24"
ComponentArrays = "0.15.16"
Enzyme = "0.12.26"
EnzymeCore = "0.7.7"
Enzyme = "0.13.1"
EnzymeCore = "0.8"
ExplicitImports = "1.9.0"
ForwardDiff = "0.10.36"
Hwloc = "3.2"
InteractiveUtils = "<0.0.1, 1"
JLArrays = "0.1.5"
LuxTestUtils = "1.2"
LuxTestUtils = "1.2.1"
MKL = "0.7"
MLDataDevices = "1.0.0"
NNlib = "0.9.21"
Expand Down

0 comments on commit 6dd7701

Please sign in to comment.