diff --git a/src/enzyme.jl b/src/enzyme.jl index 0e1b75d8a..1d4d0c20c 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -122,7 +122,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end ddsts = dst.dval - dsrcs = src.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval if EnzymeCore.EnzymeRules.width(config) == 1 ddsts = (ddsts,) @@ -182,7 +182,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end ddsts = dst.dval - dsrcs = src.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval if EnzymeCore.EnzymeRules.width(config) == 1 ddsts = (ddsts,) @@ -322,12 +322,6 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ keep = nothing end - # Cache idx if its overwritten - cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] - && !(typeof(src) <: EnzymeCore.Const) - && !(typeof(dst) <: EnzymeCore.Const) - ) ? copy(idx.val) : nothing - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep) end @@ -336,7 +330,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN val = convert(T, 1/(1-p.val)) ddsts = dst.dval - dsrcs = src.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval if EnzymeCore.EnzymeRules.width(config) == 1 ddsts = (ddsts,) diff --git a/test/conv.jl b/test/conv.jl index 4412abf3a..717cf0b82 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -895,7 +895,7 @@ end EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tw, Tw) || continue - EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (x, Tw), (idx, EnzymeCore.Const)) + EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (x, Tw), (cdims, EnzymeCore.Const)) end end end