Enzyme.jl compatibility #389

Open
CarloLucibello opened this issue Mar 1, 2024 · 8 comments
@CarloLucibello
Member

Here we will keep track of compatibility with Enzyme for taking gradients.
The first thing is to collect a few examples to run.

@askorupka askorupka self-assigned this Mar 4, 2024
@askorupka
Collaborator

Hi @CarloLucibello, as per our discussion I've set up a working example for Flux and then extended it to GraphNeuralNetworks.
Here is my code:

using Flux, Random, Enzyme, GraphNeuralNetworks
rng = Random.default_rng()

loss(model, x) = sum(model(g, g.x))

model = GNNChain(GCNConv(2=>5), 
                    BatchNorm(5), 
                    x -> relu.(x), 
                    Dense(5, 4))
x = randn(Float32, 2, 3);
g = rand_graph(3, 6)

g.ndata.x = x

grads_zygote = Flux.gradient(model->loss(model, x), model)[1]

dx = grads_enzyme = Flux.fmap(model) do x
    x isa Array ? zero(x) : x
end

Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, grads_enzyme), Duplicated(x, dx))

The last line results in an error message indicating that Enzyme's Duplicated is not compatible with GNNChain (unlike Flux's Chain).

ERROR: MethodError: no method matching Duplicated(::Int64, ::GNNChain{Tuple{GCNConv{…}, BatchNorm{…}, var"#47#48", Dense{…}}})

Closest candidates are:
  Duplicated(::T1, ::T1) where T1
   @ EnzymeCore ~/.julia/packages/EnzymeCore/XBDTI/src/EnzymeCore.jl:64
  Duplicated(::T1, ::T1, ::Bool) where T1
   @ EnzymeCore ~/.julia/packages/EnzymeCore/XBDTI/src/EnzymeCore.jl:64

Stacktrace:
 [1] top-level scope
   @ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:56
Some type information was truncated. Use `show(err)` to see complete types.

What is the best way to approach the issue?
I was looking into EnzymeCore/ line 65 and I think the method should be extended from T1 to GNNChain, but I can't find the definition of T1 anywhere. Any ideas?
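
On the T1 question: in the candidate `Duplicated(::T1, ::T1) where T1`, `T1` is a type parameter introduced by `where`, not a type defined elsewhere, so there is nothing to extend. The method already accepts any type, provided both arguments share it; the error above reports an `Int64` paired with a `GNNChain`. A minimal sketch of that dispatch behaviour, using a hypothetical `Paired` struct as a stand-in (it is an assumption for illustration, not EnzymeCore's definition):

```julia
# Hypothetical stand-in with the same constructor shape as
# `Duplicated(::T1, ::T1) where T1`; it is NOT EnzymeCore's definition.
struct Paired{T}
    val::T    # primal value
    dval::T   # shadow value, must have the same type as `val`
end

# Both arguments are Vector{Float32}, so T is inferred and the call succeeds.
same_type_ok = Paired([1.0f0, 2.0f0], zeros(Float32, 2)) isa Paired{Vector{Float32}}

# An Int and a Vector{Float32} share no single T, so no method matches.
mixed_types_fail = try
    Paired(1, [0.0f0])
    false
catch err
    err isa MethodError
end
```

If this reading is right, the fix is on the calling side (pass a value and a shadow of the same type), not in EnzymeCore.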

@CarloLucibello
Member Author

The example script has some bugs; for instance, the shape of dx was not correct. Here is the corrected script, with a simplified model as well:

using Flux, Random, Enzyme, GraphNeuralNetworks

loss(model, x) = sum(model(g, x))

model = GraphConv(2 => 5)
x = randn(Float32, 2, 3);
g = rand_graph(3, 6)

grads_zygote = Flux.gradient(loss, model, x)

dmodel = Flux.fmap(model) do x
    x isa Array ? zero(x) : x
end

dx = zero(x)

Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Duplicated(x, dx))

Enzyme throws an error here as well. The fundamental building blocks of message passing should be tested, i.e. the operations defined or used in
https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/src/msgpass.jl
Most probably we will need a rule for gather and scatter, so I would test those operations first.

@CarloLucibello
Member Author

Actually, the Enzyme rules for gather and scatter are already in NNlib:
https://github.com/FluxML/NNlib.jl/blob/master/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl
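
A minimal sketch of testing gather in isolation, before going through a full layer. It assumes Enzyme and NNlib are both loaded so that the NNlib extension rules are active; the function name `gather_loss` and the index vector are made up for illustration, and the snippet has not been checked against specific package versions. The indices are passed as an explicit `Const` argument rather than captured as a global.

```julia
using Enzyme, NNlib

# Sum over the gathered columns; `idx` is an explicit argument so that it can
# be marked Const (integer indices carry no derivative).
gather_loss(x, idx) = sum(NNlib.gather(x, idx))

x = randn(Float32, 2, 3)
idx = [1, 2, 2, 3]     # column 2 is gathered twice
dx = zero(x)           # shadow buffer that Enzyme accumulates the gradient into

Enzyme.autodiff(Reverse, gather_loss, Active, Duplicated(x, dx), Const(idx))

# Mathematically, each column's gradient is the number of times it appears in `idx`.
expected = Float32[1 2 1; 1 2 1]
```

Comparing `dx` with `expected` (or with the Zygote gradient of the same function) gives a quick check of the rule; the same pattern could be repeated for scatter.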

@askorupka
Collaborator

Thank you for the code, @CarloLucibello.
I've tried it after loading the NNlibEnzymeCore extension, but I'm not sure whether the Enzyme.autodiff call in the snippet above is correct, as it yields a StackOverflowError.
This usually happens "when the call stack exceeds its maximum size, typically due to infinite recursion" (source). I wonder whether that means the Enzyme.autodiff call is recursive in this example.

On the positive side, I no longer get the MethodError.

using Flux, Random, Enzyme, GraphNeuralNetworks
using NNlib, EnzymeCore
rng = Random.default_rng()

loss(model, x) = sum(model(g, x))

model = GraphConv(2 => 5)
x = randn(Float32, 2, 3);
g = rand_graph(3, 6)

dmodel = Flux.fmap(model) do x
    x isa Array ? zero(x) : x
end

dx = zero(x)

Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Duplicated(x, dx))

ERROR: StackOverflowError:
Stacktrace:
     [1] getproperty
       @ ./Base.jl:32 [inlined]
     [2] unwrap_unionall
       @ ./essentials.jl:379 [inlined]
     [3] fieldnames
       @ ./reflection.jl:169 [inlined]
     [4] augmented_julia_fieldnames_8170wrap
       @ ./reflection.jl:0
     [5] macro expansion
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
     [6] enzyme_call
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056 [inlined]
     [7] AugmentedForwardThunk
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009 [inlined]
     [8] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(fieldnames), df::Nothing, primal_1::Type{…}, shadow_1_1::Nothing)
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/rules/jitrules.jl:179
--- the last 6 lines are repeated 4611 more times ---
 [27675] check_num_nodes
       @ ~/.julia/dev/GraphNeuralNetworks/src/GNNGraphs/utils.jl:2
 [27676] GraphConv
       @ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:306 [inlined]

@askorupka
Collaborator

Let me paste the whole stack trace here.

Stacktrace:
     [1] getproperty
       @ ./Base.jl:32 [inlined]
     [2] unwrap_unionall
       @ ./essentials.jl:379 [inlined]
     [3] fieldnames
       @ ./reflection.jl:169 [inlined]
     [4] augmented_julia_fieldnames_8170wrap
       @ ./reflection.jl:0
     [5] macro expansion
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
     [6] enzyme_call
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056 [inlined]
     [7] AugmentedForwardThunk
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009 [inlined]
     [8] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(fieldnames), df::Nothing, primal_1::Type{…}, shadow_1_1::Nothing)
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/rules/jitrules.jl:179
--- the last 6 lines are repeated 4611 more times ---
 [27675] check_num_nodes
       @ ~/.julia/dev/GraphNeuralNetworks/src/GNNGraphs/utils.jl:2
 [27676] GraphConv
       @ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:306 [inlined]
 [27677] GraphConv
       @ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:0 [inlined]
 [27678] augmented_julia_GraphConv_6402_inner_1wrap
       @ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:0
 [27679] macro expansion
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
 [27680] enzyme_call
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056 [inlined]
 [27681] AugmentedForwardThunk
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009 [inlined]
 [27682] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::GraphConv{…}, df::GraphConv{…}, primal_1::GNNGraph{…}, shadow_1_1::Nothing, primal_2::Matrix{…}, shadow_2_1::Matrix{…})
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/rules/jitrules.jl:179
 [27683] loss
       @ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:40 [inlined]
 [27684] loss
       @ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:0 [inlined]
 [27685] augmented_julia_loss_8465_inner_1wrap
       @ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:0
 [27686] macro expansion
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
 [27687] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…})
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056
 [27688] (::Enzyme.Compiler.AugmentedForwardThunk{…})(::Const{…}, ::Duplicated{…}, ::Vararg{…})
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009
 [27689] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(loss)}, ::Type{Active}, ::Duplicated{GraphConv{…}}, ::Vararg{Any})
       @ Enzyme ~/.julia/packages/Enzyme/l4FS0/src/Enzyme.jl:198
 [27690] autodiff(::ReverseMode{false, FFIABI}, ::typeof(loss), ::Type, ::Duplicated{GraphConv{…}}, ::Vararg{Any})
       @ Enzyme ~/.julia/packages/Enzyme/l4FS0/src/Enzyme.jl:224
Some type information was truncated. Use `show(err)` to see complete types.

@askorupka
Collaborator

askorupka commented Mar 17, 2024

Interestingly, this is the first call in the stack trace that triggers the issue:

function (l::GraphConv)(g::AbstractGNNGraph, x)
    check_num_nodes(g, x)
    xj, xi = expand_srcdst(g, x)
    m = propagate(copy_xj, g, l.aggr, xj = xj)
    x = l.σ.(l.weight1 * xi .+ l.weight2 * m .+ l.bias)
    return x
end

function check_num_nodes(g::GNNGraph, x::AbstractArray)
    @assert g.num_nodes==size(x, ndims(x)) "Got $(size(x, ndims(x))) as last dimension size instead of num_nodes=$(g.num_nodes)"
    return true
end

@askorupka
Collaborator

askorupka commented Mar 17, 2024

Also linking some related issues/PRs for future testing purposes.
Flux.jl PR #2392
EnzymeAD issue #805

@CarloLucibello
Member Author

Let's focus on propagate, e.g.

f(x) = sum(propagate(copy_xj, g, +, xj = x))
dx = zero(x)
# differentiate f, the function defined above, rather than loss
Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
