Enzyme.jl compatibility #389
Hi @CarloLucibello, as per our discussion I've set up a working example for Flux and then extended it to GraphNeuralNetworks:
using Flux, Random, Enzyme, GraphNeuralNetworks
rng = Random.default_rng()
loss(model, x) = sum(model(g, g.x))
model = GNNChain(GCNConv(2=>5),
                 BatchNorm(5),
                 x -> relu.(x),
                 Dense(5, 4))
x = randn(Float32, 2, 3);
g = rand_graph(3, 6)
g.ndata.x = x
grads_zygote = Flux.gradient(model->loss(model, x), model)[1]
dx = grads_enzyme = Flux.fmap(model) do x
    x isa Array ? zero(x) : x
end
Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, grads_enzyme), Duplicated(x, dx))
The last line results in an error message indicating a missing method:
ERROR: MethodError: no method matching Duplicated(::Int64, ::GNNChain{Tuple{GCNConv{…}, BatchNorm{…}, var"#47#48", Dense{…}}})
Closest candidates are:
Duplicated(::T1, ::T1) where T1
@ EnzymeCore ~/.julia/packages/EnzymeCore/XBDTI/src/EnzymeCore.jl:64
Duplicated(::T1, ::T1, ::Bool) where T1
@ EnzymeCore ~/.julia/packages/EnzymeCore/XBDTI/src/EnzymeCore.jl:64
Stacktrace:
[1] top-level scope
@ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:56
Some type information was truncated. Use `show(err)` to see complete types.
What is the best way to approach the issue?
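The `MethodError` above is consistent with the candidate signature listed in the message, `Duplicated(::T1, ::T1) where T1`: the primal value and its shadow must have the same type. The following is a minimal sketch of that constraint on plain arrays, not a fix for the script above; the variable names are illustrative only.

```julia
using Enzyme

x = randn(Float32, 2, 3)
dx = zero(x)               # shadow with the same type and shape as x
d = Duplicated(x, dx)      # accepted: both arguments are Matrix{Float32}

# A primal and shadow of different types match no Duplicated constructor.
err = try
    Duplicated(1, dx)      # Int64 primal, Matrix{Float32} shadow
    nothing
catch e
    e
end
println(err isa MethodError)
```

In the script above, the same shadow object is bound to both `dx` and `grads_enzyme`, so `Duplicated(x, dx)` pairs `x` with a model-shaped shadow rather than with `zero(x)`.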
The example script has some bugs, for instance the shape of
using Flux, Random, Enzyme, GraphNeuralNetworks
loss(model, x) = sum(model(g, x))
model = GraphConv(2 => 5)
x = randn(Float32, 2, 3);
g = rand_graph(3, 6)
grads_zygote = Flux.gradient(loss, model, x)
dmodel = Flux.fmap(model) do x
    x isa Array ? zero(x) : x
end
dx = zero(x)
Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Duplicated(x, dx))
Enzyme throws an error here as well. The fundamental building blocks of message passing should be tested first, i.e. the operations defined or used in
Actually, the Enzyme rules for gather and scatter are already in NNlib:
Thank you for the code @CarloLucibello. On the positive side, I don't experience
using Flux, Random, Enzyme, GraphNeuralNetworks
using NNlib, EnzymeCore
rng = Random.default_rng()
loss(model, x) = sum(model(g, x))
model = GraphConv(2 => 5)
x = randn(Float32, 2, 3);
g = rand_graph(3, 6)
dmodel = Flux.fmap(model) do x
    x isa Array ? zero(x) : x
end
dx = zero(x)
Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Duplicated(x, dx))
ERROR: StackOverflowError:
Stacktrace:
[1] getproperty
@ ./Base.jl:32 [inlined]
[2] unwrap_unionall
@ ./essentials.jl:379 [inlined]
[3] fieldnames
@ ./reflection.jl:169 [inlined]
[4] augmented_julia_fieldnames_8170wrap
@ ./reflection.jl:0
[5] macro expansion
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
[6] enzyme_call
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056 [inlined]
[7] AugmentedForwardThunk
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009 [inlined]
[8] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(fieldnames), df::Nothing, primal_1::Type{…}, shadow_1_1::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/rules/jitrules.jl:179
--- the last 6 lines are repeated 4611 more times ---
[27675] check_num_nodes
@ ~/.julia/dev/GraphNeuralNetworks/src/GNNGraphs/utils.jl:2
[27676] GraphConv
@ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:306 [inlined]
Let me paste the whole stack trace here.
Stacktrace:
[1] getproperty
@ ./Base.jl:32 [inlined]
[2] unwrap_unionall
@ ./essentials.jl:379 [inlined]
[3] fieldnames
@ ./reflection.jl:169 [inlined]
[4] augmented_julia_fieldnames_8170wrap
@ ./reflection.jl:0
[5] macro expansion
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
[6] enzyme_call
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056 [inlined]
[7] AugmentedForwardThunk
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009 [inlined]
[8] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(fieldnames), df::Nothing, primal_1::Type{…}, shadow_1_1::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/rules/jitrules.jl:179
--- the last 6 lines are repeated 4611 more times ---
[27675] check_num_nodes
@ ~/.julia/dev/GraphNeuralNetworks/src/GNNGraphs/utils.jl:2
[27676] GraphConv
@ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:306 [inlined]
[27677] GraphConv
@ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:0 [inlined]
[27678] augmented_julia_GraphConv_6402_inner_1wrap
@ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:0
[27679] macro expansion
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
[27680] enzyme_call
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056 [inlined]
[27681] AugmentedForwardThunk
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009 [inlined]
[27682] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::GraphConv{…}, df::GraphConv{…}, primal_1::GNNGraph{…}, shadow_1_1::Nothing, primal_2::Matrix{…}, shadow_2_1::Matrix{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/rules/jitrules.jl:179
[27683] loss
@ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:40 [inlined]
[27684] loss
@ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:0 [inlined]
[27685] augmented_julia_loss_8465_inner_1wrap
@ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:0
[27686] macro expansion
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
[27687] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056
[27688] (::Enzyme.Compiler.AugmentedForwardThunk{…})(::Const{…}, ::Duplicated{…}, ::Vararg{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009
[27689] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(loss)}, ::Type{Active}, ::Duplicated{GraphConv{…}}, ::Vararg{Any})
@ Enzyme ~/.julia/packages/Enzyme/l4FS0/src/Enzyme.jl:198
[27690] autodiff(::ReverseMode{false, FFIABI}, ::typeof(loss), ::Type, ::Duplicated{GraphConv{…}}, ::Vararg{Any})
@ Enzyme ~/.julia/packages/Enzyme/l4FS0/src/Enzyme.jl:224
Some type information was truncated. Use `show(err)` to see complete types.
Interestingly, this is the first thing in the stack trace that causes the issue:
function (l::GraphConv)(g::AbstractGNNGraph, x)
    check_num_nodes(g, x)
    xj, xi = expand_srcdst(g, x)
    m = propagate(copy_xj, g, l.aggr, xj = xj)
    x = l.σ.(l.weight1 * xi .+ l.weight2 * m .+ l.bias)
    return x
end
function check_num_nodes(g::GNNGraph, x::AbstractArray)
    @assert g.num_nodes==size(x, ndims(x)) "Got $(size(x, ndims(x))) as last dimension size instead of num_nodes=$(g.num_nodes)"
    return true
end
Also linking some related issues/PRs for future testing purposes.
Let's focus on
f(x) = sum(propagate(copy_xj, g, +, xj = x))
dx = zero(x)
Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
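For comparison, here is a minimal sketch of the same reverse-mode calling pattern on a plain array function that does not involve `propagate` or any graph. The function `sumsq` is a hypothetical stand-in chosen because its gradient, `2x`, is easy to check by hand. It does not appear in GraphNeuralNetworks.

```julia
using Enzyme

# Hypothetical stand-in for the scalar-valued function; its gradient is 2x.
function sumsq(x)
    s = zero(eltype(x))
    for v in x
        s += v^2
    end
    return s
end

x = randn(Float32, 2, 3)
dx = zero(x)   # shadow buffer; Enzyme accumulates the gradient into it
Enzyme.autodiff(Reverse, sumsq, Active, Duplicated(x, dx))

println(dx ≈ 2 .* x)
```

If this pattern works but the `propagate` version fails, the problem is more likely in how Enzyme handles the message-passing code than in the way `autodiff` is called.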
Here we will keep track of compatibility with Enzyme for taking gradients.
The first thing to do is to collect a few examples to run.