diff --git a/ext/ReactantArrayInterfaceExt.jl b/ext/ReactantArrayInterfaceExt.jl index f23b4b3a3..522d9c78a 100644 --- a/ext/ReactantArrayInterfaceExt.jl +++ b/ext/ReactantArrayInterfaceExt.jl @@ -2,20 +2,21 @@ module ReactantArrayInterfaceExt using ArrayInterface: ArrayInterface using Reactant: - Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray, Ops + Reactant, + RArray, + ConcreteRArray, + ConcreteRNumber, + TracedRNumber, + TracedRArray, + AnyTracedRArray, + Ops ArrayInterface.can_setindex(::Type{<:RArray}) = false ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false -function ArrayInterface.aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where {T} - x_c = ConcreteRArray(zeros(T, size(x))) - x_c .= x - return x_c -end - -ArrayInterface.aos_to_soa(x::TracedRArray) = x -function ArrayInterface.aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T} - return Ops.reshape(vcat(x...), size(x)...) +for aType in + (AbstractArray{<:ConcreteRNumber}, AbstractArray{<:TracedRNumber}, AnyTracedRArray) + @eval ArrayInterface.aos_to_soa(x::$aType) = Reactant.aos_to_soa(x) end end diff --git a/src/Overlay.jl b/src/Overlay.jl index 434507f0f..f26ba2afa 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -127,10 +127,11 @@ for (cT, aT, bT) in ( @reactant_overlay @noinline function LinearAlgebra.mul!( C::$cT, A::$aT, B::$bT, α::Number, β::Number ) + A, B = aos_to_soa(A), aos_to_soa(B) if use_overlayed_version((C, A, B)) TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) else - LinearAlgebra._mul!(C, A, B, α, β) + LinearAlgebra.mul!(C, A, B, α, β) end return C end diff --git a/src/Reactant.jl b/src/Reactant.jl index 99e5e846e..8cd761c50 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -172,6 +172,22 @@ unwrapped_eltype(::RArray{T,N}) where {T,N} = T unwrapped_eltype(::AbstractArray{T,N}) where {T,N} = unwrapped_eltype(T) unwrapped_eltype(::AnyTracedRArray{T,N}) where {T,N} = T +aos_to_soa(x::AbstractArray) = x +aos_to_soa(x::AnyTracedRArray) = x +function aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where {T} + x_c = ConcreteRArray(zeros(T, size(x))) + x_c .= x + return x_c +end +function aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T} + for i in eachindex(x) + if !isassigned(x, i) + x[i] = TracedUtils.promote_to(TracedRNumber{T}, 0) + end + end + return Ops.reshape(vcat(x...), size(x)...) +end + include("Ops.jl") include("TracedUtils.jl") diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 7b491f4b7..cd7f8623a 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -449,6 +449,9 @@ end new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A)) new_traced_value(::TracedRNumber{T}) where {T} = TracedRNumber{T}((), nothing) +function broadcast_to_size(arg::AbstractArray{<:TracedRNumber}, rsize) + return broadcast_to_size(reshape(Ops.vcat(arg...), size(arg)...), rsize) +end broadcast_to_size(arg::AbstractArray, rsize) = broadcast_to_size(Ops.constant(arg), rsize) function broadcast_to_size(arg::Base.RefValue, rsize)