Skip to content

Commit

Permalink
sqrt, cbrt and log for dense diagonal matrices (#1156)
Browse files Browse the repository at this point in the history
This PR improves performance by only applying the functions to the
diagonal elements:
```julia
julia> A = diagm(0=>ones(100));

julia> @Btime log($A);
  364.163 μs (22 allocations: 401.62 KiB) # master
  13.528 μs (7 allocations: 80.02 KiB) # this PR
```
Similar improvements for `sqrt` and `cbrt` as well.
  • Loading branch information
jishnub authored Dec 23, 2024
1 parent 959d985 commit 6e5ea12
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,9 @@ julia> log(A)
"""
function log(A::AbstractMatrix)
# If possible, use diagonalization
if ishermitian(A)
if isdiag(A)
return applydiagonal(log, A)
elseif ishermitian(A)
logHermA = log(Hermitian(A))
return ishermitian(logHermA) ? copytri!(parent(logHermA), 'U', true) : parent(logHermA)
elseif istriu(A)
Expand Down Expand Up @@ -969,7 +971,9 @@ sqrt(::AbstractMatrix)

function sqrt(A::AbstractMatrix{T}) where {T<:Union{Real,Complex}}
if checksquare(A) == 0
return copy(A)
return copy(float(A))
elseif isdiag(A)
return applydiagonal(sqrt, A)
elseif ishermitian(A)
sqrtHermA = sqrt(Hermitian(A))
return ishermitian(sqrtHermA) ? copytri!(parent(sqrtHermA), 'U', true) : parent(sqrtHermA)
Expand Down Expand Up @@ -1035,7 +1039,9 @@ true
"""
function cbrt(A::AbstractMatrix{<:Real})
if checksquare(A) == 0
return copy(A)
return copy(float(A))
elseif isdiag(A)
return applydiagonal(cbrt, A)
elseif issymmetric(A)
return cbrt(Symmetric(A, :U))
else
Expand Down
16 changes: 16 additions & 0 deletions test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,7 @@ end

A13 = convert(Matrix{elty}, [2 0; 0 2])
@test typeof(log(A13)) == Array{elty, 2}
@test exp(log(A13)) log(exp(A13)) A13

T = elty == Float64 ? Symmetric : Hermitian
@test typeof(log(T(A13))) == T{elty, Array{elty, 2}}
Expand Down Expand Up @@ -968,6 +969,10 @@ end
@test typeof(sqrt(A8)) == Matrix{elty}
end
end
@testset "sqrt for diagonal" begin
A = diagm(0 => [1, 2, 3])
@test sqrt(A)^2 A
end

@testset "issue #40141" begin
x = [-1 -eps() 0 0; eps() -1 0 0; 0 0 -1 -eps(); 0 0 eps() -1]
Expand Down Expand Up @@ -1280,6 +1285,7 @@ end
T = cbrt(Symmetric(S,:U))
@test T*T*T S
@test eltype(S) == eltype(T)
@test cbrt(Array(Symmetric(S,:U))) == T
# Real valued symmetric
S = (A -> (A+A')/2)(randn(N,N))
T = cbrt(Symmetric(S,:L))
Expand All @@ -1300,6 +1306,16 @@ end
T = cbrt(A)
@test T*T*T A
@test eltype(A) == eltype(T)
@testset "diagonal" begin
A = diagm(0 => [1, 2, 3])
@test cbrt(A)^3 A
end
@testset "empty" begin
A = Matrix{Float64}(undef, 0, 0)
@test cbrt(A) == A
A = Matrix{Int}(undef, 0, 0)
@test cbrt(A) isa Matrix{Float64}
end
end

@testset "tr" begin
Expand Down

0 comments on commit 6e5ea12

Please sign in to comment.