From 873ef6cad32828aaa436875a80458bb6229f49b2 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 30 Dec 2024 21:06:36 -0500 Subject: [PATCH 1/8] Fix mul overload --- src/Overlay.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/Overlay.jl b/src/Overlay.jl index 0b5844464..bc8c62bbe 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -130,14 +130,18 @@ for (cT, aT, bT) in ( if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) else - LinearAlgebra._mul!(C, A, B, α, β) + LinearAlgebra.mul!(C, A, B, α, β) end return C end # Needed mostly for 1.10 where 3-arg mul is often specialized @reactant_overlay @noinline function LinearAlgebra.mul!(C::$cT, A::$aT, B::$bT) - call_with_reactant(LinearAlgebra.mul!, C, A, B, true, false) + if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) + TracedLinearAlgebra.overloaded_mul!(C, A, B, true, false) + else + LinearAlgebra.mul!(C, A, B) + end return C end end From 97d184a01fc72a33ce7253d1ba57a62e081b5c3f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 30 Dec 2024 21:13:41 -0500 Subject: [PATCH 2/8] fix --- src/TracedUtils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 7b491f4b7..eb9519ef2 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -449,6 +449,7 @@ end new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A)) new_traced_value(::TracedRNumber{T}) where {T} = TracedRNumber{T}((), nothing) +broadcast_to_size(arg::AbstractArray{<:TracedRNumber}, rsize) = broadcast_to_size(reshape(Ops.vcat(arg...), size(arg)...), rsize) broadcast_to_size(arg::AbstractArray, rsize) = broadcast_to_size(Ops.constant(arg), rsize) function broadcast_to_size(arg::Base.RefValue, rsize) From 411c4783d011eeb2e30a7d0d1e890f8dc8744871 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 30 Dec 2024 21:18:51 -0500 Subject: [PATCH 3/8] fix --- src/stdlibs/LinearAlgebra.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index aa56c7b92..c03df03c0 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -163,6 +163,26 @@ function TracedUtils.set_mlir_data!( return x end +function overloaded_mul!( + @nospecialize(C::TracedRArray), + @nospecialize(A::AbstractArray{<:TracedRNumber}) + @nospecialize(B::AbstractArray), + α::Number=true, + β::Number=false, +) where {T} + overloaded_mul!(C, Ops.reshape(vcat(A...), size(A)...), B, α, β) +end + +function overloaded_mul!( + @nospecialize(C::TracedRArray), + @nospecialize(A::AbstractArray), + @nospecialize(B::AbstractArray{<:TracedRNumber}) + α::Number=true, + β::Number=false, +) where {T} + overloaded_mul!(C, A, Ops.reshape(vcat(B...), size(B)...), α, β) +end + # Core functions function overloaded_mul!( @nospecialize(C::TracedRArray{T,1}), From 39e07cd0a52dadff401a7bbac5ed88c6caf7318a Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 30 Dec 2024 21:41:15 -0500 Subject: [PATCH 4/8] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/TracedUtils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index eb9519ef2..cd7f8623a 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -449,7 +449,9 @@ end new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A)) new_traced_value(::TracedRNumber{T}) where {T} = TracedRNumber{T}((), nothing) -broadcast_to_size(arg::AbstractArray{<:TracedRNumber}, rsize) = broadcast_to_size(reshape(Ops.vcat(arg...), size(arg)...), rsize) +function broadcast_to_size(arg::AbstractArray{<:TracedRNumber}, rsize) + return broadcast_to_size(reshape(Ops.vcat(arg...), size(arg)...), rsize) +end broadcast_to_size(arg::AbstractArray, rsize) = broadcast_to_size(Ops.constant(arg), rsize) function broadcast_to_size(arg::Base.RefValue, rsize) From 3b2c41002320b5f455f0c6190c6848119d325654 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 31 Dec 2024 12:17:05 -0500 Subject: [PATCH 5/8] fix: handle aos for mul (#441) * fix: handle aos for mul * Update ext/ReactantArrayInterfaceExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * revert: incorrect aos_to_soa for C --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/ReactantArrayInterfaceExt.jl | 12 +++--------- src/Overlay.jl | 7 ++----- src/Reactant.jl | 16 ++++++++++++++++ src/stdlibs/LinearAlgebra.jl | 20 -------------------- 4 files changed, 21 insertions(+), 34 deletions(-) diff --git a/ext/ReactantArrayInterfaceExt.jl b/ext/ReactantArrayInterfaceExt.jl index f23b4b3a3..cb43c436e 100644 --- a/ext/ReactantArrayInterfaceExt.jl +++ b/ext/ReactantArrayInterfaceExt.jl @@ -7,15 +7,9 @@ using Reactant: ArrayInterface.can_setindex(::Type{<:RArray}) = false ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false -function ArrayInterface.aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where {T} - x_c = ConcreteRArray(zeros(T, size(x))) - x_c .= x - return x_c -end - -ArrayInterface.aos_to_soa(x::TracedRArray) = x -function ArrayInterface.aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T} - return Ops.reshape(vcat(x...), size(x)...) +for aType in + (AbstractArray{<:ConcreteRNumber}, AbstractArray{<:TracedRNumber}, TracedRArray) + @eval ArrayInterface.aos_to_soa(x::$aType) = Reactant.aos_to_soa(x) end end diff --git a/src/Overlay.jl b/src/Overlay.jl index bc8c62bbe..d2e1ce3f5 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -127,6 +127,7 @@ for (cT, aT, bT) in ( @reactant_overlay @noinline function LinearAlgebra.mul!( C::$cT, A::$aT, B::$bT, α::Number, β::Number ) + A, B = aos_to_soa(A), aos_to_soa(B) if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) else @@ -137,11 +138,7 @@ for (cT, aT, bT) in ( # Needed mostly for 1.10 where 3-arg mul is often specialized @reactant_overlay @noinline function LinearAlgebra.mul!(C::$cT, A::$aT, B::$bT) - if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) - TracedLinearAlgebra.overloaded_mul!(C, A, B, true, false) - else - LinearAlgebra.mul!(C, A, B) - end + call_with_reactant(LinearAlgebra.mul!, C, A, B, true, false) return C end end diff --git a/src/Reactant.jl b/src/Reactant.jl index d06784c13..addf7089b 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -172,6 +172,22 @@ unwrapped_eltype(::RArray{T,N}) where {T,N} = T unwrapped_eltype(::AbstractArray{T,N}) where {T,N} = unwrapped_eltype(T) unwrapped_eltype(::AnyTracedRArray{T,N}) where {T,N} = T +aos_to_soa(x::AbstractArray) = x +aos_to_soa(x::TracedRArray) = x +function aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where {T} + x_c = ConcreteRArray(zeros(T, size(x))) + x_c .= x + return x_c +end +function aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T} + for i in eachindex(x) + if !isassigned(x, i) + x[i] = TracedUtils.promote_to(TracedRNumber{T}, 0) + end + end + return Ops.reshape(vcat(x...), size(x)...) +end + include("Ops.jl") include("TracedUtils.jl") diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index c03df03c0..aa56c7b92 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -163,26 +163,6 @@ function TracedUtils.set_mlir_data!( return x end -function overloaded_mul!( - @nospecialize(C::TracedRArray), - @nospecialize(A::AbstractArray{<:TracedRNumber}) - @nospecialize(B::AbstractArray), - α::Number=true, - β::Number=false, -) where {T} - overloaded_mul!(C, Ops.reshape(vcat(A...), size(A)...), B, α, β) -end - -function overloaded_mul!( - @nospecialize(C::TracedRArray), - @nospecialize(A::AbstractArray), - @nospecialize(B::AbstractArray{<:TracedRNumber}) - α::Number=true, - β::Number=false, -) where {T} - overloaded_mul!(C, A, Ops.reshape(vcat(B...), size(B)...), α, β) -end - # Core functions function overloaded_mul!( @nospecialize(C::TracedRArray{T,1}), From b1f5325f113af4c36e4ebf247ee0bc3c5d8cb5c4 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 31 Dec 2024 15:28:19 -0500 Subject: [PATCH 6/8] Update Reactant.jl --- src/Reactant.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index f51894c11..8cd761c50 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -173,7 +173,7 @@ unwrapped_eltype(::AbstractArray{T,N}) where {T,N} = unwrapped_eltype(T) unwrapped_eltype(::AnyTracedRArray{T,N}) where {T,N} = T aos_to_soa(x::AbstractArray) = x -aos_to_soa(x::TracedRArray) = x +aos_to_soa(x::AnyTracedRArray) = x function aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where {T} x_c = ConcreteRArray(zeros(T, size(x))) x_c .= x From 6628582ea384bd25af91d002cb05bfd9836612de Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 31 Dec 2024 15:49:05 -0500 Subject: [PATCH 7/8] Update ReactantArrayInterfaceExt.jl --- ext/ReactantArrayInterfaceExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/ReactantArrayInterfaceExt.jl b/ext/ReactantArrayInterfaceExt.jl index cb43c436e..e1ad9be25 100644 --- a/ext/ReactantArrayInterfaceExt.jl +++ b/ext/ReactantArrayInterfaceExt.jl @@ -2,13 +2,13 @@ module ReactantArrayInterfaceExt using ArrayInterface: ArrayInterface using Reactant: - Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray, Ops + Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray, AnyTracedRArray, Ops ArrayInterface.can_setindex(::Type{<:RArray}) = false ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false for aType in - (AbstractArray{<:ConcreteRNumber}, AbstractArray{<:TracedRNumber}, TracedRArray) + (AbstractArray{<:ConcreteRNumber}, AbstractArray{<:TracedRNumber}, AnyTracedRArray) @eval ArrayInterface.aos_to_soa(x::$aType) = Reactant.aos_to_soa(x) end From e53b0039d8b9e8654498e33e7d18e614a381669e Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 31 Dec 2024 15:59:15 -0500 Subject: [PATCH 8/8] Update ext/ReactantArrayInterfaceExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/ReactantArrayInterfaceExt.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ext/ReactantArrayInterfaceExt.jl b/ext/ReactantArrayInterfaceExt.jl index e1ad9be25..522d9c78a 100644 --- a/ext/ReactantArrayInterfaceExt.jl +++ b/ext/ReactantArrayInterfaceExt.jl @@ -2,7 +2,14 @@ module ReactantArrayInterfaceExt using ArrayInterface: ArrayInterface using Reactant: - Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray, AnyTracedRArray, Ops + Reactant, + RArray, + ConcreteRArray, + ConcreteRNumber, + TracedRNumber, + TracedRArray, + AnyTracedRArray, + Ops ArrayInterface.can_setindex(::Type{<:RArray}) = false ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false