From 5c0275f5adfb24d5d37fb954b1abf90b8904385d Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Sat, 21 Sep 2024 13:03:21 -0400 Subject: [PATCH 1/4] fix some cases of gradient/jacobian with StaticArrays --- ext/EnzymeStaticArraysExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index bcaa3ec6cb..e29c01cfad 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -3,7 +3,9 @@ module EnzymeStaticArraysExt using StaticArrays using Enzyme -@inline Enzyme.tupstack(rows::(NTuple{N, <:StaticArrays.SArray} where N), inshape, outshape) = reshape(cat(rows..., dims=length(inshape)), (inshape..., outshape...)) +@inline function Enzyme.tupstack(rows::(NTuple{N, <:StaticArrays.SArray} where N), inshape, outshape) + reshape(reduce(hcat, map(vec, rows)), Size(inshape..., outshape...)) +end @inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L} ntuple(Val(L)) do i From 10d22460ff35cc3631539570e14fdec4e6d09d0b Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Fri, 27 Sep 2024 18:48:50 -0400 Subject: [PATCH 2/4] add tests --- ext/EnzymeStaticArraysExt.jl | 7 ++++++ test/runtests.jl | 45 ++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index e29c01cfad..0f7cc56cf0 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -3,6 +3,13 @@ module EnzymeStaticArraysExt using StaticArrays using Enzyme +#TODO: it would be better if we could always return SArray from gradient directly, +# to retain shape information. for now we are at least able to convert +@inline function Base.convert(::Type{SArray}, tpa::Enzyme.TupleArray{T,S,L,N}) where {T,S,L,N} + SArray{Tuple{S...},T,N,L}(tpa.data) +end +@inline Base.convert(::Type{StaticArray}, tpa::Enzyme.TupleArray) = convert(SArray, tpa) + @inline function Enzyme.tupstack(rows::(NTuple{N, <:StaticArrays.SArray} where N), inshape, outshape) reshape(reduce(hcat, map(vec, rows)), Size(inshape..., outshape...)) end diff --git a/test/runtests.jl b/test/runtests.jl index d499febd77..ff863b04fa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2854,6 +2854,51 @@ end @test dx[1] ≈ 0 @test dx[2] ≈ 30 @test dx[3] ≈ 0 + + f0 = x -> sum(2*x) + f1 = x -> @SVector Float64[x[2], 2*x[2]] + f2 = x -> @SMatrix Float64[x[2] x[1]; 2*x[2] 2*x[1]] + + x = @SVector Float64[1, 2] + + dx = gradient(Forward, f0, x)[1] + @test dx isa Enzyme.TupleArray + @test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works + @test gradient(Forward, f1, x)[1] isa SMatrix + @test gradient(Forward, f1, x)[1] == [0 1.0; 0 2.0] + @test jacobian(Forward, f2, x)[1] isa SArray + @test jacobian(Forward, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2)) + + x = @SMatrix Float64[1 2; 3 4] + + dx = gradient(Forward, f0, x)[1] + @test dx isa Enzyme.TupleArray + @test convert(SArray, dx) == fill(2.0, (2,2)) + @test gradient(Forward, f1, x)[1] isa SArray + @test gradient(Forward, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2)) + @test jacobian(Forward, f2, x)[1] isa SArray + @test jacobian(Forward, f2, x)[1] == reshape( + Float64[0,0,1,2,1,2,0,0,0,0,0,0,0,0,0,0], (2,2,2,2), + ) + + x = @SVector Float64[1, 2] + + dx = gradient(Reverse, f0, x)[1] + @test dx isa SVector + @test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works + @test_broken gradient(Reverse, f1, x)[1] isa SMatrix + @test_broken gradient(Reverse, f1, x)[1] == [0 1.0; 0 2.0] + @test_broken jacobian(Reverse, f2, x)[1] isa SArray + @test_broken jacobian(Reverse, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2)) + + x = @SMatrix Float64[1 2; 3 4] + + @test_broken gradient(Reverse, f1, x)[1] isa SArray + @test_broken gradient(Reverse, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2)) + @test_broken jacobian(Reverse, f2, x)[1] isa SArray + @test_broken jacobian(Reverse, f2, x)[1] == reshape( + Float64[0,0,1,2,1,2,0,0,0,0,0,0,0,0,0,0], (2,2,2,2), + ) end function unstable_fun(A0) From 2aeef55bbdfed166d51b3d19f291fa43bf4d3dc7 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 21:06:12 -0500 Subject: [PATCH 3/4] Update EnzymeStaticArraysExt.jl --- ext/EnzymeStaticArraysExt.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index 0f7cc56cf0..8e35c57906 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -3,8 +3,6 @@ module EnzymeStaticArraysExt using StaticArrays using Enzyme -#TODO: it would be better if we could always return SArray from gradient directly, -# to retain shape information. for now we are at least able to convert @inline function Base.convert(::Type{SArray}, tpa::Enzyme.TupleArray{T,S,L,N}) where {T,S,L,N} SArray{Tuple{S...},T,N,L}(tpa.data) end From 5349faab80e0a3425b76a25bec0c8eee0accb236 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Sat, 28 Sep 2024 10:31:02 -0400 Subject: [PATCH 4/4] jacobian is exported, wtf? --- test/runtests.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 1759011a97..49448dd8b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2866,8 +2866,8 @@ end @test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works @test gradient(Forward, f1, x)[1] isa SMatrix @test gradient(Forward, f1, x)[1] == [0 1.0; 0 2.0] - @test jacobian(Forward, f2, x)[1] isa SArray - @test jacobian(Forward, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2)) + @test Enzyme.jacobian(Forward, f2, x)[1] isa SArray + @test Enzyme.jacobian(Forward, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2)) x = @SMatrix Float64[1 2; 3 4] @@ -2876,8 +2876,8 @@ end @test convert(SArray, dx) == fill(2.0, (2,2)) @test gradient(Forward, f1, x)[1] isa SArray @test gradient(Forward, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2)) - @test jacobian(Forward, f2, x)[1] isa SArray - @test jacobian(Forward, f2, x)[1] == reshape( + @test Enzyme.jacobian(Forward, f2, x)[1] isa SArray + @test Enzyme.jacobian(Forward, f2, x)[1] == reshape( Float64[0,0,1,2,1,2,0,0,0,0,0,0,0,0,0,0], (2,2,2,2), ) @@ -2888,15 +2888,15 @@ end @test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works @test_broken gradient(Reverse, f1, x)[1] isa SMatrix @test_broken gradient(Reverse, f1, x)[1] == [0 1.0; 0 2.0] - @test_broken jacobian(Reverse, f2, x)[1] isa SArray - @test_broken jacobian(Reverse, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2)) + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] isa SArray + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2)) x = @SMatrix Float64[1 2; 3 4] @test_broken gradient(Reverse, f1, x)[1] isa SArray @test_broken gradient(Reverse, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2)) - @test_broken jacobian(Reverse, f2, x)[1] isa SArray - @test_broken jacobian(Reverse, f2, x)[1] == reshape( + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] isa SArray + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] == reshape( Float64[0,0,1,2,1,2,0,0,0,0,0,0,0,0,0,0], (2,2,2,2), ) end @@ -4146,4 +4146,4 @@ include("ext/logexpfunctions.jl") @testset "BFloat16s ext" begin include("ext/bfloat16s.jl") -end \ No newline at end of file +end