diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index bcaa3ec6cb..e29c01cfad 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -3,7 +3,9 @@ module EnzymeStaticArraysExt using StaticArrays using Enzyme -@inline Enzyme.tupstack(rows::(NTuple{N, <:StaticArrays.SArray} where N), inshape, outshape) = reshape(cat(rows..., dims=length(inshape)), (inshape..., outshape...)) +@inline function Enzyme.tupstack(rows::(NTuple{N, <:StaticArrays.SArray} where N), inshape, outshape) + reshape(reduce(hcat, map(vec, rows)), Size(inshape..., outshape...)) +end @inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L} ntuple(Val(L)) do i