diff --git a/src/Manifolds.jl b/src/Manifolds.jl index 15116d975e..3102fa7981 100644 --- a/src/Manifolds.jl +++ b/src/Manifolds.jl @@ -640,6 +640,7 @@ export Euclidean, MultinomialDoubleStochastic, MultinomialMatrices, MultinomialSymmetric, + MultinomialSymmetricPositiveDefinite, Oblique, OrthogonalMatrices, PositiveArrays, diff --git a/src/manifolds/MultinomialSymmetric.jl b/src/manifolds/MultinomialSymmetric.jl index d85e195159..f68b3d7ce0 100644 --- a/src/manifolds/MultinomialSymmetric.jl +++ b/src/manifolds/MultinomialSymmetric.jl @@ -57,7 +57,7 @@ i.e. is a symmetric matrix with positive entries whose rows sum to one. function check_point(M::MultinomialSymmetric, p; kwargs...) n = get_parameter(M.size)[1] s = check_point(SymmetricMatrices(n, ℝ), p; kwargs...) - isnothing(s) && return s + !isnothing(s) && return s s2 = check_point(MultinomialMatrices(n, n), p; kwargs...) return s2 end @@ -71,7 +71,7 @@ along any row. function check_vector(M::MultinomialSymmetric, p, X; kwargs...) n = get_parameter(M.size)[1] s = check_vector(SymmetricMatrices(n, ℝ), p, X; kwargs...) - isnothing(s) && return s + !isnothing(s) && return s s2 = check_vector(MultinomialMatrices(n, n), p, X) return s2 end @@ -155,13 +155,12 @@ function Random.rand!( M::MultinomialSymmetric, pX; vector_at=nothing, - σ::Real=one(real(eltype(pX))), kwargs..., ) - rand!(rng, pX) - pX .*= σ + n = get_parameter(M.size)[1] + rand!(rng, SymmetricMatrices(n), pX; kwargs...) if vector_at === nothing - project!(M, pX, pX; kwargs...) + project!(M, pX, pX) else project!(M, pX, vector_at, pX) end @@ -203,10 +202,10 @@ function riemannian_Hessian!(M::MultinomialSymmetric, Y, p, G, H, X) # with the small change their X is our p their ξ_X is our X , Hessf is H, Gradf is G n = get_parameter(M.size)[1] ov = ones(n) # \bf 1 - I_p = lu(I .+ p) + I_p = lu(I + p) + γ = G .* p α = I_p \ (γ * ov) α_sq = (repeat(α, 1, n) .+ repeat(α', n, 1)) - γ = G .* p δ = γ .- α_sq .* p γ_dot = H .* p + G .* X α_dot = (I_p \ γ_dot .- (I_p \ X) * (I_p \ γ)) * ov diff --git a/src/manifolds/MultinomialSymmetricPositiveDefinite.jl b/src/manifolds/MultinomialSymmetricPositiveDefinite.jl index c485b2e2d2..5f87a62936 100644 --- a/src/manifolds/MultinomialSymmetricPositiveDefinite.jl +++ b/src/manifolds/MultinomialSymmetricPositiveDefinite.jl @@ -50,3 +50,24 @@ function check_vector(M::MultinomialSymmetricPositiveDefinite, p, X; kwargs...) s2 = check_vector(MultinomialMatrices(n, n), p, X) return s2 end + +function get_embedding( + ::MultinomialSymmetricPositiveDefinite{TypeParameter{Tuple{n}}}, +) where {n} + return MultinomialMatrices(n, n) +end +function get_embedding(M::MultinomialSymmetricPositiveDefinite{Tuple{Int}}) + n = get_parameter(M.size)[1] + return MultinomialMatrices(n, n; parameter=:field) +end + +function Base.show( + io::IO, + ::MultinomialSymmetricPositiveDefinite{TypeParameter{Tuple{n}}}, +) where {n} + return print(io, "MultinomialSymmetricPositiveDefinite($(n))") +end +function Base.show(io::IO, M::MultinomialSymmetricPositiveDefinite{Tuple{Int}}) + n = get_parameter(M.size)[1] + return print(io, "MultinomialSymmetricPositiveDefinite($(n); parameter=:field)") +end diff --git a/src/manifolds/Symmetric.jl b/src/manifolds/Symmetric.jl index 989d3dba29..2fd525e16c 100644 --- a/src/manifolds/Symmetric.jl +++ b/src/manifolds/Symmetric.jl @@ -226,6 +226,18 @@ project(::SymmetricMatrices, ::Any, ::Any) project!(M::SymmetricMatrices, Y, p, X) = (Y .= (X .+ transpose(X)) ./ 2) +function Random.rand!( + rng::AbstractRNG, + M::SymmetricMatrices, + pX; + σ::Real=one(real(eltype(pX))), + kwargs..., +) + rand!(rng, pX) + pX .= (σ / (2 * norm(pX))) .* (pX + pX') + return pX +end + function Base.show(io::IO, ::SymmetricMatrices{TypeParameter{Tuple{n}},F}) where {n,F} return print(io, "SymmetricMatrices($(n), $(F))") end diff --git a/test/manifolds/multinomial_doubly_stochastic.jl b/test/manifolds/multinomial_doubly_stochastic.jl index e4b8d00116..34e71e0ad0 100644 --- a/test/manifolds/multinomial_doubly_stochastic.jl +++ b/test/manifolds/multinomial_doubly_stochastic.jl @@ -68,4 +68,19 @@ include("../utils.jl") @test repr(M) == "MultinomialDoubleStochastic(3; parameter=:field)" @test get_embedding(M) === MultinomialMatrices(3, 3; parameter=:field) end + @testset "random" begin + Random.seed!(42) + p = rand(M) + @test is_point(M, p) + X = rand(M; vector_at=p) + @test is_vector(M, p, X) + end + @testset "Riemannian Gradient" begin + M = MultinomialDoubleStochastic(3) + p = ones(3, 3) ./ 3 + Y = [1.0; -1.0; 0.0 0.0; -1.0 1.0] + G = project(M, p, p .* Y) + X = riemannian_gradient(M, p, Y) + @test isapprox(M, p, G, X) + end end diff --git a/test/manifolds/multinomial_matrices.jl b/test/manifolds/multinomial_matrices.jl index c04d07410f..28fb8cb7e5 100644 --- a/test/manifolds/multinomial_matrices.jl +++ b/test/manifolds/multinomial_matrices.jl @@ -55,4 +55,12 @@ include("../utils.jl") M = MultinomialMatrices(3, 2; parameter=:field) @test repr(M) == "MultinomialMatrices(3, 2; parameter=:field)" end + @testset "Riemannian Gradient" begin + M = MultinomialMatrices(3, 2) + p = [0.5 0.4 0.1; 0.5 0.4 0.1]' + Y = [1.0; -1.0; 0.0 0.0; -1.0 1.0] + G = project(M, p, p .* Y) + X = riemannian_gradient(M, p, Y) + @test isapprox(M, p, G, X) + end end diff --git a/test/manifolds/multinomial_spd.jl b/test/manifolds/multinomial_spd.jl new file mode 100644 index 0000000000..f79bcb14fe --- /dev/null +++ b/test/manifolds/multinomial_spd.jl @@ -0,0 +1,12 @@ +include("../utils.jl") + +@testset "Multinomial symmetric positive definite matrices" begin + @testset "Basics" begin + M = MultinomialSymmetricPositiveDefinite(3) + Mf = MultinomialSymmetricPositiveDefinite(3; parameter=:field) + @test repr(M) == "MultinomialSymmetricPositiveDefinite(3)" + @test repr(Mf) == "MultinomialSymmetricPositiveDefinite(3; parameter=:field)" + @test get_embedding(M) == MultinomialMatrices(3, 3) + @test get_embedding(Mf) == MultinomialMatrices(3, 3; parameter=:field) + end +end diff --git a/test/manifolds/multinomial_symmetric.jl b/test/manifolds/multinomial_symmetric.jl index 58450b5bac..f1c862aaa0 100644 --- a/test/manifolds/multinomial_symmetric.jl +++ b/test/manifolds/multinomial_symmetric.jl @@ -73,4 +73,22 @@ include("../utils.jl") @test repr(M) == "MultinomialSymmetric(3; parameter=:field)" @test get_embedding(M) === MultinomialMatrices(3, 3; parameter=:field) end + @testset "random" begin + Random.seed!(42) + p = rand(M) + @test is_point(M, p) + X = rand(M; vector_at=p) + @test is_vector(M, p, X) + end + @testset "Hessian call" begin + p = ones(3, 3) ./ 3 + Y = one(p) + G = zero(p) + H = 0.5 * one(p) + X = riemannian_Hessian(M, p, G, H, Y) + X2 = similar(X) + riemannian_Hessian!(M, X, p, G, H, Y) + @test isapprox(M, p, X, X2) + @test is_vector(M, p, X) + end end diff --git a/test/runtests.jl b/test/runtests.jl index afb113cc85..69895b4427 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -148,6 +148,7 @@ include("utils.jl") include_test("manifolds/lorentz.jl") include_test("manifolds/multinomial_doubly_stochastic.jl") include_test("manifolds/multinomial_symmetric.jl") + include_test("manifolds/multinomial_spd.jl") include_test("manifolds/positive_numbers.jl") include_test("manifolds/probability_simplex.jl") include_test("manifolds/projective_space.jl")