Skip to content

Commit

Permalink
Fix mat * vec with non-commutative numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Dec 6, 2023
1 parent ce4519d commit 56bfb23
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 12 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "BandedMatrices"
uuid = "aae01518-5342-5314-be14-df237901396f"
version = "1.3"
version = "1.3.1"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand All @@ -24,6 +24,7 @@ GenericLinearAlgebra = "0.3"
InfiniteArrays = "0.12, 0.13"
LinearAlgebra = "1.6"
PrecompileTools = "1"
Quaternions = "0.7"
Random = "1.6"
SparseArrays = "1.6"
Test = "1.6"
Expand All @@ -34,9 +35,10 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
InfiniteArrays = "4858937d-0d70-526a-a4dd-2d5cb5dd786c"
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Documenter", "GenericLinearAlgebra", "InfiniteArrays", "Random", "SparseArrays", "Test"]
test = ["Aqua", "Documenter", "GenericLinearAlgebra", "InfiniteArrays", "Random", "SparseArrays", "Test", "Quaternions"]
2 changes: 1 addition & 1 deletion src/BandedMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import ArrayLayouts: MemoryLayout, transposelayout, triangulardata,
triangularlayout, MatLdivVec, hermitianlayout, hermitiandata,
materialize, materialize!, BlasMatMulMatAdd, BlasMatMulVecAdd, BlasMatLmulVec, BlasMatLdivVec,
colsupport, rowsupport, symmetricuplo, MatMulMatAdd, MatMulVecAdd,
sublayout, sub_materialize, _fill_lmul!, _copy_oftype,
sublayout, sub_materialize, _copy_oftype, zero!,
reflector!, reflectorApply!, _copyto!, checkdimensions,
_qr!, _qr, _lu!, _lu, _factorize, AbstractTridiagonalLayout, TridiagonalLayout,
BidiagonalLayout, bidiagonaluplo, diagonaldata, supdiagonaldata, subdiagonaldata, copymutable_oftype_layout, dualadjoint
Expand Down
16 changes: 8 additions & 8 deletions src/generic/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ end
@inline function materialize!(M::MatMulVecAdd{<:AbstractBandedLayout})
checkdimensions(M)
α,A,B,β,C = M.α,M.A,M.B,M.β,M.C
_fill_lmul!(β, C)
_fill_rmul!(C, β)
@inbounds for j = intersect(rowsupport(A), colsupport(B))
for k = colrange(A,j)
C[k] += α*inbands_getindex(A,k,j)*B[j]
C[k] += inbands_getindex(A,k,j) * B[j] * α
end
end
C
Expand All @@ -113,11 +113,11 @@ end
checkdimensions(M)
α,At,B,β,C = M.α,M.A,M.B,M.β,M.C
A = transpose(At)
_fill_lmul!(β, C)
_fill_rmul!(C, β)

@inbounds for j = rowsupport(A)
for k = intersect(colrange(A,j), colsupport(B))
C[j] += α*transpose(inbands_getindex(A,k,j))*B[k]
C[j] += transpose(inbands_getindex(A,k,j)) * B[k] * α
end
end
C
Expand All @@ -127,10 +127,10 @@ end
checkdimensions(M)
α,Ac,B,β,C = M.α,M.A,M.B,M.β,M.C
A = Ac'
_fill_lmul!(β, C)
_fill_rmul!(C, β)
@inbounds for j = rowsupport(A)
for k = intersect(colrange(A,j), colsupport(B))
C[j] += α*inbands_getindex(A,k,j)'*B[k]
C[j] += inbands_getindex(A,k,j)' * B[k] * α
end
end
C
Expand Down Expand Up @@ -227,13 +227,13 @@ end

function materialize!(M::MatMulMatAdd{<:DiagonalLayout{<:AbstractFillLayout},<:AbstractBandedLayout})
checkdimensions(M)
M.C .= (M.α * getindex_value(M.A.diag)) .* M.B .+ M.β .* M.C
M.C .= getindex_value(M.A.diag) .* M.B .* M.α .+ M.C .* M.β
M.C
end

function materialize!(M::MatMulMatAdd{<:AbstractBandedLayout,<:DiagonalLayout{<:AbstractFillLayout}})
checkdimensions(M)
M.C .= (M.α * getindex_value(M.B.diag)) .* M.A .+ M.β .* M.C
M.C .= getindex_value(M.B.diag) .* M.A .* M.α .+ M.C .* M.β
M.C
end

Expand Down
4 changes: 4 additions & 0 deletions src/generic/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ prodbandwidths(A...) = broadcast(+, bandwidths.(A)...)
function sumbandwidths(A::AbstractMatrix, B::AbstractMatrix)
max(bandwidth(A, 1), bandwidth(B, 1)), max(bandwidth(A, 2), bandwidth(B, 2))
end


_fill_lmul!(β, A::AbstractArray{T}) where T = iszero(β) ? zero!(A) : lmul!(β, A)
_fill_rmul!(A::AbstractArray{T}, β) where T = iszero(β) ? zero!(A) : rmul!(A, β)
13 changes: 12 additions & 1 deletion test/test_linalg.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
using ArrayLayouts, BandedMatrices, FillArrays, LinearAlgebra, Test
using ArrayLayouts
using BandedMatrices
using FillArrays
using LinearAlgebra
using Quaternions
using Test

import Base.Broadcast: materialize, broadcasted
import BandedMatrices: BandedColumns, _BandedMatrix
Expand Down Expand Up @@ -109,6 +114,12 @@ ArrayLayouts.colsupport(::UnknownLayout, A::MyOneElement{<:Any,1}, _) =
@test B*M Matrix(B)*M
@test M*B M*Matrix(B)
end
@testset "non-commutative" begin
B1 = BandedMatrix(0 => [quat(rand(4)...) for i in 1:3])
v = [quat(rand(4)...) for i in 1:3]
α, β = quat(0,1,1,0), quat(1,0,0,1)
@test mul!(zero(v), B1, v, α, β) mul!(zero(v), Array(B1), v, α, β)
end
end

@testset "BandedMatrix * sparse" begin
Expand Down

0 comments on commit 56bfb23

Please sign in to comment.