Skip to content

Commit

Permalink
readd sum checks. I am not sure where we lost those.
Browse files Browse the repository at this point in the history
reintroduce them from 2020.
  • Loading branch information
kellertuer committed Jan 17, 2024
1 parent 81a43d4 commit 0d7fb13
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions src/manifolds/MultinomialSymmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,22 @@ i.e. is a symmetric matrix with positive entries whose rows sum to one.
"""
function check_point(M::MultinomialSymmetric, p; kwargs...)
n = get_parameter(M.size)[1]
return check_point(SymmetricMatrices(n, ℝ), p)
s = check_point(SymmetricMatrices(n, ℝ), p)
isnothing(s) && return s
r = sum(p, dims=2)
if !isapprox(r, ones(n, 1); kwargs...)
return DomainError(
r,
"The point $(p) does not lie on $M, since its rows do not sum up to one.",
)
end
if !(minimum(p) > 0) || !(maximum(p) < 1)
return DomainError(
minimum(p),
"The point $(p) does not lie on $M, since at least one of its entries is nonpositive.",
)
end
return nothing
end
@doc raw"""
check_vector(M::MultinomialSymmetric p, X; kwargs...)
Expand All @@ -67,7 +82,16 @@ along any row.
"""
function check_vector(M::MultinomialSymmetric, p, X; kwargs...)
n = get_parameter(M.size)[1]
return check_vector(SymmetricMatrices(n, ℝ), p, X; kwargs...)
s = check_vector(SymmetricMatrices(n, ℝ), p, X; kwargs...)
isnothing(s) && return s
r = sum(X, dims=2) # due to symmetry, we only have to check columns
if !isapprox(r, zeros(n); kwargs...)
return DomainError(
r,
"The matrix $(X) is not a tangent vector to $(p) on $(M), since its columns/rows do not sum up to zero.",
)
end
return nothing
end

embed!(::MultinomialSymmetric, q, p) = copyto!(q, p)
Expand Down Expand Up @@ -115,8 +139,7 @@ The two vector $α ∈ ℝ^{n×n}$ is given by solving
````math
(I_n+p)α = Y\mathbf{1},
````
where $I_n$ is teh $n×n$ unit matrix and $\mathbf{1}_n$ is the vector of length $n$ containing ones.
where ``I_n`` is teh ``n×n`` unit matrix and ``\mathbf{1}_n`` is the vector of length ``n`` containing ones.
"""
project(::MultinomialSymmetric, ::Any, ::Any)

Expand Down

0 comments on commit 0d7fb13

Please sign in to comment.