diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 4b71a1341..f5e4475a2 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -283,7 +283,7 @@ function overload_autodiff( end end for (i, act) in enumerate(activity) - if act == enzyme_out || (reverse && (act == enzyme_dup || act == enzyme_dupnoneed)) + if act == enzyme_out || act == enzyme_dup || act == enzyme_dupnoneed if width == 1 push!(outtys, in_tys[i]) else diff --git a/test/autodiff.jl b/test/autodiff.jl index 5cf1726d0..842050413 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -11,7 +11,7 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y)) res1 = @jit( fwd( - set_abi(Forward, Reactant.ReactantABI), + Forward, Duplicated, ConcreteRArray(ones(3, 2)), ConcreteRArray(3.1 * ones(3, 2)), @@ -42,12 +42,7 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y)) @test typeof(ores1) == Tuple{} res1 = @jit( - fwd( - set_abi(Forward, Reactant.ReactantABI), - Const, - ConcreteRArray(ones(3, 2)), - ConcreteRArray(3.1 * ones(3, 2)), - ) + fwd(Forward, Const, ConcreteRArray(ones(3, 2)), ConcreteRArray(3.1 * ones(3, 2))) ) @test typeof(res1) == Tuple{}