Skip to content

Commit

Permalink
generate a bit more test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
kellertuer committed Jan 21, 2024
1 parent 7c0e2df commit 6b0234f
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/Manifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ export Euclidean,
MultinomialDoubleStochastic,
MultinomialMatrices,
MultinomialSymmetric,
MultinomialSymmetricPositiveDefinite,
Oblique,
OrthogonalMatrices,
PositiveArrays,
Expand Down
15 changes: 7 additions & 8 deletions src/manifolds/MultinomialSymmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ i.e. is a symmetric matrix with positive entries whose rows sum to one.
function check_point(M::MultinomialSymmetric, p; kwargs...)
n = get_parameter(M.size)[1]
s = check_point(SymmetricMatrices(n, ℝ), p; kwargs...)
isnothing(s) && return s
!isnothing(s) && return s
s2 = check_point(MultinomialMatrices(n, n), p; kwargs...)
return s2

Check warning on line 62 in src/manifolds/MultinomialSymmetric.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/MultinomialSymmetric.jl#L59-L62

Added lines #L59 - L62 were not covered by tests
end
Expand All @@ -71,7 +71,7 @@ along any row.
function check_vector(M::MultinomialSymmetric, p, X; kwargs...)
n = get_parameter(M.size)[1]
s = check_vector(SymmetricMatrices(n, ℝ), p, X; kwargs...)
isnothing(s) && return s
!isnothing(s) && return s
s2 = check_vector(MultinomialMatrices(n, n), p, X)
return s2

Check warning on line 76 in src/manifolds/MultinomialSymmetric.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/MultinomialSymmetric.jl#L73-L76

Added lines #L73 - L76 were not covered by tests
end
Expand Down Expand Up @@ -155,13 +155,12 @@ function Random.rand!(
M::MultinomialSymmetric,
pX;
vector_at=nothing,
σ::Real=one(real(eltype(pX))),
kwargs...,
)
rand!(rng, pX)
pX .*= σ
n = get_parameter(M.size)[1]
rand!(rng, SymmetricMatrices(n), pX; kwargs...)
if vector_at === nothing
project!(M, pX, pX; kwargs...)
project!(M, pX, pX)

Check warning on line 163 in src/manifolds/MultinomialSymmetric.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/MultinomialSymmetric.jl#L160-L163

Added lines #L160 - L163 were not covered by tests
else
project!(M, pX, vector_at, pX)

Check warning on line 165 in src/manifolds/MultinomialSymmetric.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/MultinomialSymmetric.jl#L165

Added line #L165 was not covered by tests
end
Expand Down Expand Up @@ -203,10 +202,10 @@ function riemannian_Hessian!(M::MultinomialSymmetric, Y, p, G, H, X)
# with the small change their X is our p their ξ_X is our X , Hessf is H, Gradf is G
n = get_parameter(M.size)[1]
ov = ones(n) # \bf 1
I_p = lu(I .+ p)
I_p = lu(I + p)
γ = G .* p
α = I_p \* ov)
α_sq = (repeat(α, 1, n) .+ repeat', n, 1))
γ = G .* p
δ = γ .- α_sq .* p
γ_dot = H .* p + G .* X
α_dot = (I_p \ γ_dot .- (I_p \ X) * (I_p \ γ)) * ov
Expand Down
21 changes: 21 additions & 0 deletions src/manifolds/MultinomialSymmetricPositiveDefinite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,24 @@ function check_vector(M::MultinomialSymmetricPositiveDefinite, p, X; kwargs...)
s2 = check_vector(MultinomialMatrices(n, n), p, X)
return s2

Check warning on line 51 in src/manifolds/MultinomialSymmetricPositiveDefinite.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/MultinomialSymmetricPositiveDefinite.jl#L46-L51

Added lines #L46 - L51 were not covered by tests
end

function get_embedding(

Check warning on line 54 in src/manifolds/MultinomialSymmetricPositiveDefinite.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/MultinomialSymmetricPositiveDefinite.jl#L54

Added line #L54 was not covered by tests
::MultinomialSymmetricPositiveDefinite{TypeParameter{Tuple{n}}},
) where {n}
return MultinomialMatrices(n, n)

Check warning on line 57 in src/manifolds/MultinomialSymmetricPositiveDefinite.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/MultinomialSymmetricPositiveDefinite.jl#L57

Added line #L57 was not covered by tests
end
function get_embedding(M::MultinomialSymmetricPositiveDefinite{Tuple{Int}})
n = get_parameter(M.size)[1]
return MultinomialMatrices(n, n; parameter=:field)

Check warning on line 61 in src/manifolds/MultinomialSymmetricPositiveDefinite.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/MultinomialSymmetricPositiveDefinite.jl#L59-L61

Added lines #L59 - L61 were not covered by tests
end

function Base.show(

Check warning on line 64 in src/manifolds/MultinomialSymmetricPositiveDefinite.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/MultinomialSymmetricPositiveDefinite.jl#L64

Added line #L64 was not covered by tests
io::IO,
::MultinomialSymmetricPositiveDefinite{TypeParameter{Tuple{n}}},
) where {n}
return print(io, "MultinomialSymmetricPositiveDefinite($(n))")

Check warning on line 68 in src/manifolds/MultinomialSymmetricPositiveDefinite.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/MultinomialSymmetricPositiveDefinite.jl#L68

Added line #L68 was not covered by tests
end
function Base.show(io::IO, M::MultinomialSymmetricPositiveDefinite{Tuple{Int}})
n = get_parameter(M.size)[1]
return print(io, "MultinomialSymmetricPositiveDefinite($(n); parameter=:field)")

Check warning on line 72 in src/manifolds/MultinomialSymmetricPositiveDefinite.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/MultinomialSymmetricPositiveDefinite.jl#L70-L72

Added lines #L70 - L72 were not covered by tests
end
12 changes: 12 additions & 0 deletions src/manifolds/Symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,18 @@ project(::SymmetricMatrices, ::Any, ::Any)

project!(M::SymmetricMatrices, Y, p, X) = (Y .= (X .+ transpose(X)) ./ 2)

function Random.rand!(

Check warning on line 229 in src/manifolds/Symmetric.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/Symmetric.jl#L229

Added line #L229 was not covered by tests
rng::AbstractRNG,
M::SymmetricMatrices,
pX;
σ::Real=one(real(eltype(pX))),
kwargs...,
)
rand!(rng, pX)
pX .=/ (2 * norm(pX))) .* (pX + pX')
return pX

Check warning on line 238 in src/manifolds/Symmetric.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/Symmetric.jl#L236-L238

Added lines #L236 - L238 were not covered by tests
end

function Base.show(io::IO, ::SymmetricMatrices{TypeParameter{Tuple{n}},F}) where {n,F}
return print(io, "SymmetricMatrices($(n), $(F))")
end
Expand Down
15 changes: 15 additions & 0 deletions test/manifolds/multinomial_doubly_stochastic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,19 @@ include("../utils.jl")
@test repr(M) == "MultinomialDoubleStochastic(3; parameter=:field)"
@test get_embedding(M) === MultinomialMatrices(3, 3; parameter=:field)
end
@testset "random" begin
Random.seed!(42)
p = rand(M)
@test is_point(M, p)
X = rand(M; vector_at=p)
@test is_vector(M, p, X)
end
@testset "Riemannian Gradient" begin
M = MultinomialDoubleStochastic(3)
p = ones(3, 3) ./ 3
Y = [1.0; -1.0; 0.0 0.0; -1.0 1.0]
G = project(M, p, p .* Y)
X = riemannian_gradient(M, p, Y)
@test isapprox(M, p, G, X)
end
end
8 changes: 8 additions & 0 deletions test/manifolds/multinomial_matrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,12 @@ include("../utils.jl")
M = MultinomialMatrices(3, 2; parameter=:field)
@test repr(M) == "MultinomialMatrices(3, 2; parameter=:field)"
end
@testset "Riemannian Gradient" begin
M = MultinomialMatrices(3, 2)
p = [0.5 0.4 0.1; 0.5 0.4 0.1]'
Y = [1.0; -1.0; 0.0 0.0; -1.0 1.0]
G = project(M, p, p .* Y)
X = riemannian_gradient(M, p, Y)
@test isapprox(M, p, G, X)
end
end
12 changes: 12 additions & 0 deletions test/manifolds/multinomial_spd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
include("../utils.jl")

@testset "Multinomial symmetric positive definite matrices" begin
@testset "Basics" begin
M = MultinomialSymmetricPositiveDefinite(3)
Mf = MultinomialSymmetricPositiveDefinite(3; parameter=:field)
@test repr(M) == "MultinomialSymmetricPositiveDefinite(3)"
@test repr(Mf) == "MultinomialSymmetricPositiveDefinite(3; parameter=:field)"
@test get_embedding(M) == MultinomialMatrices(3, 3)
@test get_embedding(Mf) == MultinomialMatrices(3, 3; parameter=:field)
end
end
18 changes: 18 additions & 0 deletions test/manifolds/multinomial_symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,22 @@ include("../utils.jl")
@test repr(M) == "MultinomialSymmetric(3; parameter=:field)"
@test get_embedding(M) === MultinomialMatrices(3, 3; parameter=:field)
end
@testset "random" begin
Random.seed!(42)
p = rand(M)
@test is_point(M, p)
X = rand(M; vector_at=p)
@test is_vector(M, p, X)
end
@testset "Hessian call" begin
p = ones(3, 3) ./ 3
Y = one(p)
G = zero(p)
H = 0.5 * one(p)
X = riemannian_Hessian(M, p, G, H, Y)
X2 = similar(X)
riemannian_Hessian!(M, X, p, G, H, Y)
@test isapprox(M, p, X, X2)
@test is_vector(M, p, X)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ include("utils.jl")
include_test("manifolds/lorentz.jl")
include_test("manifolds/multinomial_doubly_stochastic.jl")
include_test("manifolds/multinomial_symmetric.jl")
include_test("manifolds/multinomial_spd.jl")
include_test("manifolds/positive_numbers.jl")
include_test("manifolds/probability_simplex.jl")
include_test("manifolds/projective_space.jl")
Expand Down

0 comments on commit 6b0234f

Please sign in to comment.