Skip to content

Commit

Permalink
More fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 8, 2023
1 parent 6e64553 commit 256a4fb
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions src/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
end

ddsts = dst.dval
dsrcs = src.dval
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval

if EnzymeCore.EnzymeRules.width(config) == 1
ddsts = (ddsts,)
Expand Down Expand Up @@ -182,7 +182,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
end

ddsts = dst.dval
dsrcs = src.dval
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval

if EnzymeCore.EnzymeRules.width(config) == 1
ddsts = (ddsts,)
Expand Down Expand Up @@ -322,12 +322,6 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
keep = nothing
end

# Cache idx if its overwritten
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4]
&& !(typeof(src) <: EnzymeCore.Const)
&& !(typeof(dst) <: EnzymeCore.Const)
) ? copy(idx.val) : nothing

return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep)
end

Expand All @@ -336,7 +330,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
val = convert(T, 1/(1-p.val))

ddsts = dst.dval
dsrcs = src.dval
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval

if EnzymeCore.EnzymeRules.width(config) == 1
ddsts = (ddsts,)
Expand Down

0 comments on commit 256a4fb

Please sign in to comment.