diff --git a/src/enzyme.jl b/src/enzyme.jl index 0e1b75d8a..1d4d0c20c 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -122,7 +122,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end ddsts = dst.dval - dsrcs = src.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval if EnzymeCore.EnzymeRules.width(config) == 1 ddsts = (ddsts,) @@ -182,7 +182,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end ddsts = dst.dval - dsrcs = src.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval if EnzymeCore.EnzymeRules.width(config) == 1 ddsts = (ddsts,) @@ -322,12 +322,6 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ keep = nothing end - # Cache idx if its overwritten - cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] - && !(typeof(src) <: EnzymeCore.Const) - && !(typeof(dst) <: EnzymeCore.Const) - ) ? copy(idx.val) : nothing - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep) end @@ -336,7 +330,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN val = convert(T, 1/(1-p.val)) ddsts = dst.dval - dsrcs = src.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval if EnzymeCore.EnzymeRules.width(config) == 1 ddsts = (ddsts,)