Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge identical methods for Symmetric/Hermitian and SymTridiagonal #56434

Merged
merged 2 commits into from
Nov 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 65 additions & 86 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,16 @@ convert(::Type{T}, m::Union{Symmetric,Hermitian}) where {T<:Hermitian} = m isa T

const HermOrSym{T, S} = Union{Hermitian{T,S}, Symmetric{T,S}}
const RealHermSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}}
const SymSymTri{T} = Union{Symmetric{T}, SymTridiagonal{T}}
const RealHermSymSymTri{T<:Real} = Union{RealHermSym{T}, SymTridiagonal{T}}
const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}}
const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}}
const RealHermSymSymTriComplexHerm{T<:Real} = Union{RealHermSymComplexSym{T}, SymTridiagonal{T}}
const SelfAdjoint = Union{Symmetric{<:Real}, Hermitian{<:Number}}

wrappertype(::Union{Symmetric, SymTridiagonal}) = Symmetric
wrappertype(::Hermitian) = Hermitian

size(A::HermOrSym) = size(A.data)
axes(A::HermOrSym) = axes(A.data)
@inline function Base.isassigned(A::HermOrSym, i::Int, j::Int)
Expand Down Expand Up @@ -814,15 +820,15 @@ end
^(A::Symmetric{<:Complex}, p::Integer) = sympow(A, p)
^(A::SymTridiagonal{<:Real}, p::Integer) = sympow(A, p)
^(A::SymTridiagonal{<:Complex}, p::Integer) = sympow(A, p)
function sympow(A::SymSymTri, p::Integer)
if p < 0
return Symmetric(Base.power_by_squaring(inv(A), -p))
else
return Symmetric(Base.power_by_squaring(A, p))
end
end
for hermtype in (:Symmetric, :SymTridiagonal)
@eval begin
function sympow(A::$hermtype, p::Integer)
if p < 0
return Symmetric(Base.power_by_squaring(inv(A), -p))
else
return Symmetric(Base.power_by_squaring(A, p))
end
end
function ^(A::$hermtype{<:Real}, p::Real)
isinteger(p) && return integerpow(A, p)
F = eigen(A)
Expand All @@ -844,8 +850,8 @@ function ^(A::Hermitian, p::Integer)
else
retmat = Base.power_by_squaring(A, p)
end
for i = 1:size(A,1)
retmat[i,i] = real(retmat[i,i])
for i in diagind(retmat, IndexStyle(retmat))
retmat[i] = real(retmat[i])
end
return Hermitian(retmat)
end
Expand All @@ -857,8 +863,8 @@ function ^(A::Hermitian{T}, p::Real) where T
if T <: Real
return Hermitian(retmat)
else
for i = 1:size(A,1)
retmat[i,i] = real(retmat[i,i])
for i in diagind(retmat, IndexStyle(retmat))
retmat[i] = real(retmat[i])
end
return Hermitian(retmat)
end
Expand All @@ -873,34 +879,25 @@ function ^(A::Hermitian{T}, p::Real) where T
end

for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt)
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function ($func)(A::$hermtype{<:Real})
F = eigen(A)
return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
end
end
end
@eval begin
function ($func)(A::RealHermSymSymTri)
F = eigen(A)
return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
end
function ($func)(A::Hermitian{<:Complex})
n = checksquare(A)
F = eigen(A)
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
for i = 1:n
retmat[i,i] = real(retmat[i,i])
for i in diagind(retmat, IndexStyle(retmat))
retmat[i] = real(retmat[i])
end
return Hermitian(retmat)
end
end
end

for wrapper in (:Symmetric, :Hermitian, :SymTridiagonal)
@eval begin
function cis(A::$wrapper{<:Real})
F = eigen(A)
return Symmetric(F.vectors .* cis.(F.values') * F.vectors')
end
end
function cis(A::RealHermSymSymTri)
F = eigen(A)
return Symmetric(F.vectors .* cis.(F.values') * F.vectors')
end
function cis(A::Hermitian{<:Complex})
F = eigen(A)
Expand All @@ -909,26 +906,21 @@ end


for func in (:acos, :asin)
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function ($func)(A::$hermtype{<:Real})
F = eigen(A)
if all(λ -> -1 ≤ λ ≤ 1, F.values)
return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
end
@eval begin
function ($func)(A::RealHermSymSymTri)
F = eigen(A)
if all(λ -> -1 ≤ λ ≤ 1, F.values)
return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
end
end
end
@eval begin
function ($func)(A::Hermitian{<:Complex})
n = checksquare(A)
F = eigen(A)
if all(λ -> -1 ≤ λ ≤ 1, F.values)
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
for i = 1:n
retmat[i,i] = real(retmat[i,i])
for i in diagind(retmat, IndexStyle(retmat))
retmat[i] = real(retmat[i])
end
return Hermitian(retmat)
else
Expand All @@ -938,58 +930,49 @@ for func in (:acos, :asin)
end
end

for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function acosh(A::$hermtype{<:Real})
F = eigen(A)
if all(λ -> λ ≥ 1, F.values)
return $wrapper((F.vectors * Diagonal(acosh.(F.values))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors')
end
end
function acosh(A::RealHermSymSymTri)
F = eigen(A)
if all(λ -> λ ≥ 1, F.values)
return wrappertype(A)((F.vectors * Diagonal(acosh.(F.values))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors')
end
end
function acosh(A::Hermitian{<:Complex})
n = checksquare(A)
F = eigen(A)
if all(λ -> λ ≥ 1, F.values)
retmat = (F.vectors * Diagonal(acosh.(F.values))) * F.vectors'
for i = 1:n
retmat[i,i] = real(retmat[i,i])
for i in diagind(retmat, IndexStyle(retmat))
retmat[i] = real(retmat[i])
end
return Hermitian(retmat)
else
return (F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors'
end
end

for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function sincos(A::$hermtype{<:Real})
n = checksquare(A)
F = eigen(A)
T = float(eltype(F.values))
S, C = Diagonal(similar(A, T, (n,))), Diagonal(similar(A, T, (n,)))
for i in 1:n
S.diag[i], C.diag[i] = sincos(F.values[i])
end
return $wrapper((F.vectors * S) * F.vectors'), $wrapper((F.vectors * C) * F.vectors')
end
function sincos(A::RealHermSymSymTri)
n = checksquare(A)
F = eigen(A)
T = float(eltype(F.values))
S, C = Diagonal(similar(A, T, (n,))), Diagonal(similar(A, T, (n,)))
for i in eachindex(S.diag, C.diag, F.values)
S.diag[i], C.diag[i] = sincos(F.values[i])
end
return wrappertype(A)((F.vectors * S) * F.vectors'), wrappertype(A)((F.vectors * C) * F.vectors')
end
function sincos(A::Hermitian{<:Complex})
n = checksquare(A)
F = eigen(A)
T = float(eltype(F.values))
S, C = Diagonal(similar(A, T, (n,))), Diagonal(similar(A, T, (n,)))
for i in 1:n
for i in eachindex(S.diag, C.diag, F.values)
S.diag[i], C.diag[i] = sincos(F.values[i])
end
retmatS, retmatC = (F.vectors * S) * F.vectors', (F.vectors * C) * F.vectors'
for i = 1:n
retmatS[i,i] = real(retmatS[i,i])
retmatC[i,i] = real(retmatC[i,i])
for i in diagind(retmatS, IndexStyle(retmatS))
retmatS[i] = real(retmatS[i])
retmatC[i] = real(retmatC[i])
end
return Hermitian(retmatS), Hermitian(retmatC)
end
Expand All @@ -999,28 +982,24 @@ for func in (:log, :sqrt)
# sqrt has rtol arg to handle matrices that are semidefinite up to roundoff errors
rtolarg = func === :sqrt ? Any[Expr(:kw, :(rtol::Real), :(eps(real(float(one(T))))*size(A,1)))] : Any[]
rtolval = func === :sqrt ? :(-maximum(abs, F.values) * rtol) : 0
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function ($func)(A::$hermtype{T}; $(rtolarg...)) where {T<:Real}
F = eigen(A)
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
if all(λ -> λ ≥ λ₀, F.values)
return $wrapper((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
end
@eval begin
function ($func)(A::RealHermSymSymTri{T}; $(rtolarg...)) where {T<:Real}
F = eigen(A)
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
if all(λ -> λ ≥ λ₀, F.values)
return wrappertype(A)((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
end
end
end
@eval begin
function ($func)(A::Hermitian{T}; $(rtolarg...)) where {T<:Complex}
n = checksquare(A)
F = eigen(A)
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
if all(λ -> λ ≥ λ₀, F.values)
retmat = (F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors'
for i = 1:n
retmat[i,i] = real(retmat[i,i])
for i in diagind(retmat, IndexStyle(retmat))
retmat[i] = real(retmat[i])
end
return Hermitian(retmat)
else
Expand Down