diff --git a/src/manifolds/MultinomialSymmetric.jl b/src/manifolds/MultinomialSymmetric.jl index 23b178b368..ae89797cd4 100644 --- a/src/manifolds/MultinomialSymmetric.jl +++ b/src/manifolds/MultinomialSymmetric.jl @@ -56,7 +56,22 @@ i.e. is a symmetric matrix with positive entries whose rows sum to one. """ function check_point(M::MultinomialSymmetric, p; kwargs...) n = get_parameter(M.size)[1] - return check_point(SymmetricMatrices(n, ℝ), p) + s = check_point(SymmetricMatrices(n, ℝ), p) + isnothing(s) && return s + r = sum(p, dims=2) + if !isapprox(r, ones(n, 1); kwargs...) + return DomainError( + r, + "The point $(p) does not lie on $M, since its rows do not sum up to one.", + ) + end + if !(minimum(p) > 0) || !(maximum(p) < 1) + return DomainError( + minimum(p), + "The point $(p) does not lie on $M, since at least one of its entries is nonpositive.", + ) + end + return nothing end @doc raw""" check_vector(M::MultinomialSymmetric p, X; kwargs...) @@ -67,7 +82,16 @@ along any row. """ function check_vector(M::MultinomialSymmetric, p, X; kwargs...) n = get_parameter(M.size)[1] - return check_vector(SymmetricMatrices(n, ℝ), p, X; kwargs...) + s = check_vector(SymmetricMatrices(n, ℝ), p, X; kwargs...) + isnothing(s) && return s + r = sum(X, dims=2) # due to symmetry, we only have to check columns + if !isapprox(r, zeros(n); kwargs...) + return DomainError( + r, + "The matrix $(X) is not a tangent vector to $(p) on $(M), since its columns/rows do not sum up to zero.", + ) + end + return nothing end embed!(::MultinomialSymmetric, q, p) = copyto!(q, p) @@ -115,8 +139,7 @@ The two vector $α ∈ ℝ^{n×n}$ is given by solving ````math (I_n+p)α = Y\mathbf{1}, ```` -where $I_n$ is teh $n×n$ unit matrix and $\mathbf{1}_n$ is the vector of length $n$ containing ones. - +where ``I_n`` is teh ``n×n`` unit matrix and ``\mathbf{1}_n`` is the vector of length ``n`` containing ones. """ project(::MultinomialSymmetric, ::Any, ::Any)