Skip to content

Commit

Permalink
Fix mul overload (#440)
Browse files Browse the repository at this point in the history
* Fix mul overload

* fix

* fix

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix: handle aos for mul (#441)

* fix: handle aos for mul

* Update ext/ReactantArrayInterfaceExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* revert: incorrect aos_to_soa for C

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update Reactant.jl

* Update ReactantArrayInterfaceExt.jl

* Update ext/ReactantArrayInterfaceExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Avik Pal <[email protected]>
  • Loading branch information
3 people authored Dec 31, 2024
1 parent 885264e commit 85d8ba4
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 11 deletions.
21 changes: 11 additions & 10 deletions ext/ReactantArrayInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@ module ReactantArrayInterfaceExt

using ArrayInterface: ArrayInterface
using Reactant:
Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray, Ops
Reactant,
RArray,
ConcreteRArray,
ConcreteRNumber,
TracedRNumber,
TracedRArray,
AnyTracedRArray,
Ops

ArrayInterface.can_setindex(::Type{<:RArray}) = false
ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false

function ArrayInterface.aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where {T}
x_c = ConcreteRArray(zeros(T, size(x)))
x_c .= x
return x_c
end

ArrayInterface.aos_to_soa(x::TracedRArray) = x
function ArrayInterface.aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T}
return Ops.reshape(vcat(x...), size(x)...)
for aType in
(AbstractArray{<:ConcreteRNumber}, AbstractArray{<:TracedRNumber}, AnyTracedRArray)
@eval ArrayInterface.aos_to_soa(x::$aType) = Reactant.aos_to_soa(x)
end

end
3 changes: 2 additions & 1 deletion src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,11 @@ for (cT, aT, bT) in (
@reactant_overlay @noinline function LinearAlgebra.mul!(
C::$cT, A::$aT, B::$bT, α::Number, β::Number
)
A, B = aos_to_soa(A), aos_to_soa(B)
if use_overlayed_version((C, A, B))
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
else
LinearAlgebra._mul!(C, A, B, α, β)
LinearAlgebra.mul!(C, A, B, α, β)
end
return C
end
Expand Down
16 changes: 16 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,22 @@ unwrapped_eltype(::RArray{T,N}) where {T,N} = T
unwrapped_eltype(::AbstractArray{T,N}) where {T,N} = unwrapped_eltype(T)
unwrapped_eltype(::AnyTracedRArray{T,N}) where {T,N} = T

aos_to_soa(x::AbstractArray) = x
aos_to_soa(x::AnyTracedRArray) = x
function aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where {T}
x_c = ConcreteRArray(zeros(T, size(x)))
x_c .= x
return x_c
end
function aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T}
for i in eachindex(x)
if !isassigned(x, i)
x[i] = TracedUtils.promote_to(TracedRNumber{T}, 0)
end
end
return Ops.reshape(vcat(x...), size(x)...)
end

include("Ops.jl")
include("TracedUtils.jl")

Expand Down
3 changes: 3 additions & 0 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,9 @@ end
new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A))
new_traced_value(::TracedRNumber{T}) where {T} = TracedRNumber{T}((), nothing)

function broadcast_to_size(arg::AbstractArray{<:TracedRNumber}, rsize)
return broadcast_to_size(reshape(Ops.vcat(arg...), size(arg)...), rsize)
end
broadcast_to_size(arg::AbstractArray, rsize) = broadcast_to_size(Ops.constant(arg), rsize)

function broadcast_to_size(arg::Base.RefValue, rsize)
Expand Down

0 comments on commit 85d8ba4

Please sign in to comment.