diff --git a/Project.toml b/Project.toml index 4d5a2273b5..bcc2ca8a05 100644 --- a/Project.toml +++ b/Project.toml @@ -40,7 +40,7 @@ Adapt = "3, 4" CUDA = "4, 5" ChainRulesCore = "1.12" Compat = "4.10.0" -Enzyme = "0.11" +Enzyme = "0.12.4" FiniteDifferences = "0.12" Functors = "0.4" MLUtils = "0.4" diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index 36212bb10f..bef48c4da0 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -120,6 +120,7 @@ end (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"), (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), + (GRU(3 => 5), randn(Float32, 3, 10), "GRU"), ] for (model, x, name) in models_xs @@ -164,7 +165,7 @@ end device = Flux.get_device() models_xs = [ - (GRU(3 => 5), randn(Float32, 3, 10), "GRU"), + # Pending https://github.com/FluxML/NNlib.jl/issues/565 (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"), ]