Skip to content

Commit

Permalink
fix: return types for Enzyme ForwardDiff
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 19, 2024
1 parent 1c9e3e8 commit edf2d71
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ function overload_autodiff(
end
end
for (i, act) in enumerate(activity)
if act == enzyme_out || (reverse && (act == enzyme_dup || act == enzyme_dupnoneed))
if act == enzyme_out || act == enzyme_dup || act == enzyme_dupnoneed
if width == 1
push!(outtys, in_tys[i])
else
Expand Down
9 changes: 2 additions & 7 deletions test/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y))

res1 = @jit(
fwd(
set_abi(Forward, Reactant.ReactantABI),
Forward,
Duplicated,
ConcreteRArray(ones(3, 2)),
ConcreteRArray(3.1 * ones(3, 2)),
Expand Down Expand Up @@ -42,12 +42,7 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y))
@test typeof(ores1) == Tuple{}

res1 = @jit(
fwd(
set_abi(Forward, Reactant.ReactantABI),
Const,
ConcreteRArray(ones(3, 2)),
ConcreteRArray(3.1 * ones(3, 2)),
)
fwd(Forward, Const, ConcreteRArray(ones(3, 2)), ConcreteRArray(3.1 * ones(3, 2)))
)

@test typeof(res1) == Tuple{}
Expand Down

0 comments on commit edf2d71

Please sign in to comment.