From 962e1913a4cee7f1249f5c4b85dcf98e92c7240a Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 7 May 2024 16:43:09 -0400 Subject: [PATCH 1/5] Enzyme: bump version and mark models as working [test] --- Project.toml | 2 +- test/ext_enzyme/enzyme.jl | 31 ++----------------------------- 2 files changed, 3 insertions(+), 30 deletions(-) diff --git a/Project.toml b/Project.toml index 4d5a2273b5..eb3b34801e 100644 --- a/Project.toml +++ b/Project.toml @@ -40,7 +40,7 @@ Adapt = "3, 4" CUDA = "4, 5" ChainRulesCore = "1.12" Compat = "4.10.0" -Enzyme = "0.11" +Enzyme = "0.12" FiniteDifferences = "0.12" Functors = "0.4" MLUtils = "0.4" diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index 36212bb10f..8665e18376 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -120,6 +120,8 @@ end (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"), (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), + (GRU(3 => 5), randn(Float32, 3, 10), "GRU"), + (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"), ] for (model, x, name) in models_xs @@ -154,32 +156,3 @@ end end end end - -@testset "Broken Models" begin - function loss(model, x) - Flux.reset!(model) - sum(model(x)) - end - - device = Flux.get_device() - - models_xs = [ - (GRU(3 => 5), randn(Float32, 3, 10), "GRU"), - (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"), - ] - - for (model, x, name) in models_xs - @testset "check grad $name" begin - println("testing $name") - broken = false - try - test_enzyme_grad(loss, model, x) - catch e - println(e) - broken = true - end - @test broken - end - end -end - From a6d8b31daf5255df2905a5f853a8ec3b79283f86 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 9 May 2024 13:15:35 -0700 Subject: [PATCH 2/5] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index eb3b34801e..a7fa7500fa 100644 --- a/Project.toml +++ b/Project.toml @@ -40,7 +40,7 @@ Adapt = "3, 4" CUDA = "4, 5" ChainRulesCore = "1.12" Compat = "4.10.0" -Enzyme = "0.12" +Enzyme = "0.12.3" FiniteDifferences = "0.12" Functors = "0.4" MLUtils = "0.4" From cde7acf7599664a10c2e5d89316729e6410d92a9 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 10 May 2024 17:40:35 -0700 Subject: [PATCH 3/5] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a7fa7500fa..bcc2ca8a05 100644 --- a/Project.toml +++ b/Project.toml @@ -40,7 +40,7 @@ Adapt = "3, 4" CUDA = "4, 5" ChainRulesCore = "1.12" Compat = "4.10.0" -Enzyme = "0.12.3" +Enzyme = "0.12.4" FiniteDifferences = "0.12" Functors = "0.4" MLUtils = "0.4" From a2b49ee5ad5aac0781e715a641c274cd47acea3d Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 10 May 2024 18:56:17 -0700 Subject: [PATCH 4/5] Update enzyme.jl From 504ac8545851fe5f6c0c140a18bf37ed3e3a53c9 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 11 May 2024 00:26:00 -0400 Subject: [PATCH 5/5] Mark transpose as not supported --- test/ext_enzyme/enzyme.jl | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index 8665e18376..bef48c4da0 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -121,7 +121,6 @@ end (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), (GRU(3 => 5), randn(Float32, 3, 10), "GRU"), - (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"), ] for (model, x, name) in models_xs @@ -156,3 +155,32 @@ end end end end + +@testset "Broken Models" begin + function loss(model, x) + Flux.reset!(model) + sum(model(x)) + end + + device = Flux.get_device() + + models_xs = [ + # Pending https://github.com/FluxML/NNlib.jl/issues/565 + (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"), + ] + + for (model, x, name) in models_xs + @testset "check grad $name" begin + println("testing $name") + broken = false + try + test_enzyme_grad(loss, model, x) + catch e + println(e) + broken = true + end + @test broken + end + end +end +