Skip to content

Commit

Permalink
Forward complex matrix multiplication to components (#294)
Browse files Browse the repository at this point in the history
* Forward complex matrix multiplication to components

* mul for vectors

* Convert tabs to spaces

* Remove extra end
  • Loading branch information
jishnub authored Jul 26, 2024
1 parent 58aba83 commit 9ca8264
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 3 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

Expand All @@ -22,6 +24,7 @@ StructArraysAdaptExt = "Adapt"
StructArraysGPUArraysCoreExt = "GPUArraysCore"
StructArraysSparseArraysExt = "SparseArrays"
StructArraysStaticArraysExt = "StaticArrays"
StructArraysLinearAlgebraExt = "LinearAlgebra"

[compat]
Adapt = "3.4, 4"
Expand Down
25 changes: 25 additions & 0 deletions ext/StructArraysLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module StructArraysLinearAlgebraExt

using StructArrays
using LinearAlgebra
import LinearAlgebra: mul!

const StructMatrixC{T, A<:AbstractMatrix{T}} = StructArrays.StructMatrix{Complex{T}, @NamedTuple{re::A, im::A}}
const StructVectorC{T, A<:AbstractVector{T}} = StructArrays.StructVector{Complex{T}, @NamedTuple{re::A, im::A}}

function _mul!(C, A, B, alpha, beta)
mul!(C.re, A.re, B.re, alpha, beta)
mul!(C.re, A.im, B.im, -alpha, oneunit(beta))
mul!(C.im, A.re, B.im, alpha, beta)
mul!(C.im, A.im, B.re, alpha, oneunit(beta))
C
end

function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end
function mul!(C::StructVectorC, A::StructMatrixC, B::StructVectorC, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end

end
1 change: 1 addition & 0 deletions src/StructArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ end
include("../ext/StructArraysGPUArraysCoreExt.jl")
include("../ext/StructArraysSparseArraysExt.jl")
include("../ext/StructArraysStaticArraysExt.jl")
include("../ext/StructArraysLinearAlgebraExt.jl")
end

end # module
1 change: 1 addition & 0 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ _structarray(args::Tuple, ::Tuple) = _structarray(args, nothing)
_structarray(args::NTuple{N, Any}, names::NTuple{N, Symbol}) where {N} = StructArray(NamedTuple{names}(args))

const StructVector{T, C<:Tup, I} = StructArray{T, 1, C, I}
const StructMatrix{T, C<:Tup, I} = StructArray{T, 2, C, I}
StructVector{T}(args...; kwargs...) where {T} = StructArray{T}(args...; kwargs...)
StructVector(args...; kwargs...) = StructArray(args...; kwargs...)

Expand Down
23 changes: 20 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1206,7 +1206,7 @@ end
# The following code defines `MyArray1/2/3` with different `BroadcastStyle`s.
# 1. `MyArray1` and `MyArray1` have `similar` defined.
# We use them to simulate `BroadcastStyle` overloading `Base.copyto!`.
# 2. `MyArray3` has no `similar` defined.
# 2. `MyArray3` has no `similar` defined.
# We use it to simulate `BroadcastStyle` overloading `Base.copy`.
# 3. Their resolved style could be summaryized as (`-` means conflict)
# | MyArray1 | MyArray2 | MyArray3 | Array
Expand Down Expand Up @@ -1302,7 +1302,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
f(s) = s .+= 1
f(s)
@test (@allocated f(s)) == 0

# issue #185
A = StructArray(randn(ComplexF64, 3, 3))
B = randn(ComplexF64, 3, 3)
Expand All @@ -1321,7 +1321,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS

@testset "ambiguity check" begin
test_set = Any[StructArray([1;2+im]),
1:2,
1:2,
(1,2),
StructArray(@SArray [1;1+2im]),
(@SArray [1 2]),
Expand Down Expand Up @@ -1550,6 +1550,23 @@ end
@test Base.IteratorSize(S) == Base.IsInfinite()
end

@testset "LinearAlgebra" begin
@testset "matrix * matrix" begin
A = StructArray{ComplexF64}((rand(10,10), rand(10,10)))
B = StructArray{ComplexF64}((rand(size(A)...), rand(size(A)...)))
MA, MB = Matrix(A), Matrix(B)
@test A * B MA * MB
@test mul!(ones(ComplexF64,size(A)), A, B, 2.0, 3.0) 2 * A * B .+ 3
end
@testset "matrix * vector" begin
A = StructArray{ComplexF64}((rand(10,10), rand(10,10)))
v = StructArray{ComplexF64}((rand(size(A,2)), rand(size(A,2))))
MA, Mv = Matrix(A), Vector(v)
@test A * v MA * Mv
@test mul!(ones(ComplexF64,size(v)), A, v, 2.0, 3.0) 2 * A * v .+ 3
end
end

@testset "project quality" begin
Aqua.test_all(StructArrays, ambiguities=(; broken=true))
end

0 comments on commit 9ca8264

Please sign in to comment.