diff --git a/src/TensorCore.jl b/src/TensorCore.jl index e274d29..920f4c2 100644 --- a/src/TensorCore.jl +++ b/src/TensorCore.jl @@ -129,7 +129,7 @@ function tensor!(dest::AbstractArray, A::AbstractArray, B::AbstractArray) return dest end -export boxdot, ⊡, boxdot! +export boxdot, ⊡, ⊡₂, boxdot! """ boxdot(A,B) = A ⊡ B # \\boxdot @@ -177,40 +177,55 @@ Float64 ``` See also `boxdot!(Y,A,B)`, which is to `⊡` as `mul!` is to `*`. """ -function boxdot(A::AbstractArray, B::AbstractArray) - Amat = _squash_left(A) - Bmat = _squash_right(B) +function boxdot(A::AbstractArray, B::AbstractArray, nth::Val) + _check_boxdot_axes(A, B, nth) + Amat = _squash_left(A, nth) + Bmat = _squash_right(B, nth) axA, axB = axes(Amat,2), axes(Bmat,1) axA == axB || _throw_dmm(axA, axB) - return _boxdot_reshape(Amat * Bmat, A, B) + return _boxdot_reshape(Amat * Bmat, A, B, nth) end +boxdot(A::AbstractArray, B::AbstractArray) = boxdot(A, B, Val(1)) +boxdot2(A::AbstractArray, B::AbstractArray) = boxdot(A, B, Val(2)) + const ⊡ = boxdot +const ⊡₂ = boxdot2 @noinline _throw_dmm(axA, axB) = throw(DimensionMismatch("neighbouring axes of `A` and `B` must match, got $axA and $axB")) +@noinline _throw_boxdot_nth(n) = throw(ArgumentError("boxdot order should be ≥ 1, got $n")) + +function _check_boxdot_axes(A::AbstractArray{<:Any,N}, B::AbstractArray{<:Any,M}, ::Val{K}) where {N,M,K} + K::Int + (K >= 1) || _throw_boxdot_nth(K) + for i in 1:K + axA, axB = axes(A)[N-K+i], axes(B)[i] + axA == axB || _throw_dmm(axA, axB) + end +end -_squash_left(A::AbstractArray) = reshape(A, :,size(A,ndims(A))) -_squash_left(A::AbstractMatrix) = A +_squash_left(A::AbstractArray, ::Val{N}) where {N} = reshape(A, prod(size(A)[1:end-N]),:) +_squash_left(A::AbstractMatrix, ::Val{1}) = A -_squash_right(B::AbstractArray) = reshape(B, size(B,1),:) -_squash_right(B::AbstractVecOrMat) = B +_squash_right(B::AbstractArray, ::Val{N}) where {N} = reshape(B, :,prod(size(B)[1+N:end])) +_squash_right(B::AbstractVecOrMat, ::Val{1}) = B -function _boxdot_reshape(AB::AbstractArray, A::AbstractArray{T,N}, B::AbstractArray{S,M}) where {T,N,S,M} - ax = ntuple(i -> i i≤N-K ? axes(A, i) : axes(B, i-N+2K), Val(N+M-2K)) reshape(AB, ax) # some cases don't come here, so this doesn't really support OffsetArrays end # These can skip final reshape: -_boxdot_reshape(AB::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat) = AB +_boxdot_reshape(AB::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, ::Val) = AB # These produce scalar output: -function boxdot(A::AbstractVector, B::AbstractVector) - axA, axB = axes(A,1), axes(B,1) - axA == axB || _throw_dmm(axA, axB) +function boxdot(A::AbstractArray{<:Any,N}, B::AbstractArray{<:Any,N}, ::Val{N}) where {N} + _check_boxdot_axes(A, B, Val(N)) if eltype(A) <: Number - return transpose(A)*B + return transpose(vec(A))*vec(B) else return sum(a*b for (a,b) in zip(A,B)) end @@ -224,30 +239,30 @@ boxdot(a::Number, b::Number) = a*b using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec # Adjont and Transpose, vectors or almost (returning a scalar) -boxdot(A::AdjointAbsVec, B::AbstractVector) = A * B -boxdot(A::TransposeAbsVec, B::AbstractVector) = A * B +boxdot(A::AdjointAbsVec, B::AbstractVector, ::Val{1}) = A * B +boxdot(A::TransposeAbsVec, B::AbstractVector, ::Val{1}) = A * B -boxdot(A::AbstractVector, B::AdjointAbsVec) = A ⊡ vec(B) -boxdot(A::AbstractVector, B::TransposeAbsVec) = A ⊡ vec(B) +boxdot(A::AbstractVector, B::AdjointAbsVec, ::Val{1}) = A ⊡ vec(B) +boxdot(A::AbstractVector, B::TransposeAbsVec, ::Val{1}) = A ⊡ vec(B) -boxdot(A::AdjointAbsVec, B::AdjointAbsVec) = adjoint(adjoint(B) ⊡ adjoint(A)) -boxdot(A::AdjointAbsVec, B::TransposeAbsVec) = vec(A) ⊡ vec(B) -boxdot(A::TransposeAbsVec, B::AdjointAbsVec) = vec(A) ⊡ vec(B) -boxdot(A::TransposeAbsVec, B::TransposeAbsVec) = transpose(transpose(B) ⊡ transpose(A)) +boxdot(A::AdjointAbsVec, B::AdjointAbsVec, ::Val{1}) = adjoint(adjoint(B) ⊡ adjoint(A)) +boxdot(A::AdjointAbsVec, B::TransposeAbsVec, ::Val{1}) = vec(A) ⊡ vec(B) +boxdot(A::TransposeAbsVec, B::AdjointAbsVec, ::Val{1}) = vec(A) ⊡ vec(B) +boxdot(A::TransposeAbsVec, B::TransposeAbsVec, ::Val{1}) = transpose(transpose(B) ⊡ transpose(A)) # ... with a matrix (returning another such) -boxdot(A::AdjointAbsVec, B::AbstractMatrix) = A * B -boxdot(A::TransposeAbsVec, B::AbstractMatrix) = A * B +boxdot(A::AdjointAbsVec, B::AbstractMatrix, ::Val{1}) = A * B +boxdot(A::TransposeAbsVec, B::AbstractMatrix, ::Val{1}) = A * B -boxdot(A::AbstractMatrix, B::AdjointAbsVec) = (B' ⊡ A')' -boxdot(A::AbstractMatrix, B::TransposeAbsVec) = transpose(transpose(B) ⊡ transpose(A)) +boxdot(A::AbstractMatrix, B::AdjointAbsVec, ::Val{1}) = (B' ⊡ A')' +boxdot(A::AbstractMatrix, B::TransposeAbsVec, ::Val{1}) = transpose(transpose(B) ⊡ transpose(A)) # ... and with higher-dim (returning a plain array) -boxdot(A::AdjointAbsVec, B::AbstractArray) = vec(A) ⊡ B -boxdot(A::TransposeAbsVec, B::AbstractArray) = vec(A) ⊡ B +boxdot(A::AdjointAbsVec, B::AbstractArray, ::Val{1}) = vec(A) ⊡ B +boxdot(A::TransposeAbsVec, B::AbstractArray, ::Val{1}) = vec(A) ⊡ B -boxdot(A::AbstractArray, B::AdjointAbsVec) = A ⊡ vec(B) -boxdot(A::AbstractArray, B::TransposeAbsVec) = A ⊡ vec(B) +boxdot(A::AbstractArray, B::AdjointAbsVec, ::Val{1}) = A ⊡ vec(B) +boxdot(A::AbstractArray, B::TransposeAbsVec, ::Val{1}) = A ⊡ vec(B) """ @@ -260,25 +275,30 @@ function boxdot! end if VERSION < v"1.3" # Then 5-arg mul! isn't defined - function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray) - szY = prod(size(A)[1:end-1]), prod(size(B)[2:end]) - mul!(reshape(Y, szY), _squash_left(A), _squash_right(B)) + function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, ::Val{N}) where {N} + _check_boxdot_axes(A, B, Val(N)) + szY = prod(size(A)[1:end-N]), prod(size(B)[1+N:end]) + mul!(reshape(Y, szY), _squash_left(A, Val(N)), _squash_right(B, Val(N))) Y end - boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec) = boxdot!(Y, A, vec(B)) + boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray) = boxdot!(Y, A, B, Val(1)) + boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec) = boxdot!(Y, A, vec(B), Val(1)) else - function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, α::Number=true, β::Number=false) - szY = prod(size(A)[1:end-1]), prod(size(B)[2:end]) - mul!(reshape(Y, szY), _squash_left(A), _squash_right(B), α, β) + function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, ::Val{N}, α::Number=true, β::Number=false) where {N} + _check_boxdot_axes(A, B, Val(N)) + szY = prod(size(A)[1:end-N]), prod(size(B)[1+N:end]) + mul!(reshape(Y, szY), _squash_left(A, Val(N)), _squash_right(B, Val(N)), α, β) Y end + boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, α::Number=true, β::Number=false) = boxdot!(Y, A, B, Val(1), α, β) + # For boxdot!, only where mul! behaves differently: boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec, - α::Number=true, β::Number=false) = boxdot!(Y, A, vec(B), α, β) + α::Number=true, β::Number=false) = boxdot!(Y, A, vec(B), Val(1), α, β) end diff --git a/test/runtests.jl b/test/runtests.jl index 5657f2f..2e1aa4a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -279,6 +279,59 @@ end @test boxdot!(similar(c,1,2), c', A) == c' * A @test boxdot!(similar(c,1), c', d) == [dot(c, d)] + + @testset "higher-order boxdot" begin + @test A ⊡₂ A isa Complex + @test boxdot(E3, E3, Val(3)) isa Complex + @test boxdot(F4, F4, Val(4)) isa Complex + @test A ⊡₂ A == sum(A .* A) + @test boxdot(E3, E3, Val(3)) == sum(E3 .* E3) + @test boxdot(F4, F4, Val(4)) == sum(F4 .* F4) + + @test size(A ⊡₂ E3) == (2,) + @test A ⊡₂ E3 == vec(reshape(A, 1,:) * reshape(E3, :,2)) + @test A ⊡₂ E3lazy == A ⊡₂ E3 + @test E3 ⊡₂ A' == vec((A ⊡₂ E3adjoint)') + @test E3 ⊡₂ transpose(A) == A ⊡₂ conj(E3adjoint) + + @test size(A ⊡₂ F4) == (2,2) + @test A ⊡₂ F4 == reshape(reshape(A, 1,:) * reshape(F4, :,4), 2,2) + @test A ⊡₂ F4lazy == A ⊡₂ F4 + @test F4lazy ⊡₂ A == F4 ⊡₂ A + + @test size(F4 ⊡₂ E3) == (2,2,2) + @test F4 ⊡₂ E3 == reshape(reshape(F4, 4,:) * reshape(E3, :,2), 2,2,2) + @test F4 ⊡₂ E3 == F4lazy ⊡₂ E3lazy + + # In-place + @test boxdot!(similar(c), A, E3, Val(2)) == A ⊡₂ E3 + if VERSION >= v"1.3" + @test boxdot!(similar(c), A, E3, Val(2), 100) == A ⊡₂ E3 * 100 + @test boxdot!(copy(c), B, E3, Val(2), 100, -5) == B ⊡₂ E3 * 100 .- 5 .* c + end + + @test boxdot!(similar(c,1), A, A, Val(2)) == [A ⊡₂ A] + @test boxdot!(similar(c,2,2), A, F4, Val(2)) == A ⊡₂ F4 + @test boxdot!(similar(c,2,2,2), F4, E3, Val(2)) == F4 ⊡₂ E3 + + # Errors + @test_throws DimensionMismatch ones(2,2) ⊡₂ ones(3,2) + @test_throws DimensionMismatch ones(2,2) ⊡₂ ones(2,3) + @test_throws DimensionMismatch ones(2,2,2) ⊡₂ ones(2,3,2) + @test_throws BoundsError ones(2,2) ⊡₂ ones(2) + @test_throws BoundsError ones(2) ⊡₂ ones(2,2) + @test_throws ArgumentError boxdot(ones(2), ones(2), Val(-1)) + @test_throws TypeError boxdot(ones(2), ones(2), Val(UInt(1))) + + @test_throws DimensionMismatch boxdot!(similar(c,1), ones(2,2), ones(3,2), Val(2)) + @test_throws DimensionMismatch boxdot!(similar(c,1), ones(2,2), ones(2,3), Val(2)) + @test_throws DimensionMismatch boxdot!(similar(c,2,2), ones(2,2,2), ones(2,3,2), Val(2)) + @test_throws BoundsError boxdot!(similar(c,1), ones(2,2), ones(2), Val(2)) + @test_throws BoundsError boxdot!(similar(c,1), ones(2), ones(2,2), Val(2)) + @test_throws DimensionMismatch boxdot!(similar(c,2,3), ones(2,2,3), ones(2,3,2), Val(2)) + @test_throws ArgumentError boxdot!(similar(c,1), ones(2), ones(2), Val(-1)) + @test_throws TypeError boxdot!(similar(c,1), ones(2), ones(2), Val(UInt(1))) + end end @testset "_adjoint" begin