I reduced the Lux reproducer down to just Enzyme + Reactant:
```julia
using Reactant, Enzyme

function simple_reduce(w, x)
    y = sum(w * x; dims=1)  # (10×3) * (3×5) -> 10×5, reduced to 1×5
    return sum(y)
end

w = Reactant.ConcreteRArray(randn(Float32, 10, 3))
x = Reactant.ConcreteRArray(randn(Float32, 3, 5))

simple_reduce_xla = Reactant.compile(simple_reduce, (w, x))
simple_reduce_xla(w, x)  # Works

function simple_reduce_grad(w, x)
    dw = Enzyme.make_zero(w)
    dx = Enzyme.make_zero(x)
    Enzyme.autodiff(
        Enzyme.Reverse, simple_reduce, Active, Duplicated(w, dw), Duplicated(x, dx))
    return dw, dx
end

simple_reduce_grad_xla = Reactant.compile(simple_reduce_grad, (w, x))
simple_reduce_grad_xla(w, x)
```
This fails with:

```
error: size of operand dimension 0 (5) is not equal to 1 or size of result dimension 0 (10)
Pipeline failed
```
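To check the compiled gradient once it runs, it helps to have a reference computed without Reactant or Enzyme. Because `sum(sum(w * x; dims=1)) == sum(w * x)`, the seed flowing back into `w * x` is a 10×5 matrix of ones, so `dw = ones(10, 5) * x'` and `dx = w' * ones(10, 5)`. The sketch below uses only Base Julia on plain `Array`s; the names `simple_reduce_ref` and `simple_reduce_grad_ref` are hypothetical helpers, not part of either package.

```julia
# Hypothetical plain-Julia reference, no Reactant/Enzyme involved.
function simple_reduce_ref(w, x)
    y = sum(w * x; dims=1)   # 10×5 product reduced to 1×5
    return sum(y)
end

# Analytic gradients of f(w, x) = sum(w * x):
#   df/dw[i,k] = sum_j x[k,j]    df/dx[k,j] = sum_i w[i,k]
function simple_reduce_grad_ref(w, x)
    seed = ones(eltype(w), size(w, 1), size(x, 2))  # df/d(w*x), size 10×5
    dw = seed * x'   # (10×5) * (5×3) -> 10×3, same shape as w
    dx = w' * seed   # (3×10) * (10×5) -> 3×5, same shape as x
    return dw, dx
end

w = randn(Float32, 10, 3)
x = randn(Float32, 3, 5)
f = simple_reduce_ref(w, x)
dw, dx = simple_reduce_grad_ref(w, x)
```

With the same input values, the outputs of `simple_reduce_grad_xla` should match `dw` and `dx` up to floating-point tolerance.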
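If you replace the function with the version below, which drops the intermediate `dims=1` reduction, the error goes away:

```julia
function simple_reduce(w, x)
    return sum(w * x)
end
```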
I think this stems from the reduce operation being generated incorrectly; see LuxDL/Lux.jl#665 (comment). I couldn't reduce that case further yet.
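One possible reading of the error, offered as an assumption rather than a diagnosis: the message appears to match the shape check performed by StableHLO's `broadcast_in_dim`, and in reverse mode the adjoint of `sum(z; dims=1)` is a broadcast of the 1×n cotangent back to the m×n shape of `z`. A broadcast can only stretch dimensions of size 1, so stretching a size-5 dimension to 10 is rejected. The plain-Julia sketch below (the helper `sum_dims1_pullback` is hypothetical, and dimension numbering in Reactant's emitted IR may differ from Julia's) shows the shapes the correct adjoint should have, and the same kind of mismatch the error reports.

```julia
# Hypothetical reverse-mode rule for r = sum(z; dims=1), z of size (m, n):
# the cotangent of z is the (1, n) cotangent of r expanded back to (m, n).
function sum_dims1_pullback(rbar::AbstractMatrix, m::Integer)
    size(rbar, 1) == 1 || throw(DimensionMismatch("expected a 1×n cotangent, got $(size(rbar))"))
    return repeat(rbar, m, 1)   # expand the singleton dimension from 1 to m
end

z = randn(Float32, 10, 5)
r = sum(z; dims=1)                       # 1×5
rbar = ones(Float32, size(r))            # cotangent seeded by the outer sum
zbar = sum_dims1_pullback(rbar, size(z, 1))  # 10×5, all ones

# Valid broadcast: the size-1 first dimension expands to 10.
ok = ones(Float32, 1, 5) .+ zeros(Float32, 10, 5)

# Invalid broadcast: a size-5 first dimension cannot become 10.
mismatch = try
    ones(Float32, 5, 1) .+ zeros(Float32, 10, 5)
    false
catch err
    err isa DimensionMismatch
end
```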
This should now be fixed by the jll bump.