From c7fc4513ccc6e7ef8d166834e5c31cf5288b585d Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Sat, 23 Dec 2023 15:58:28 +0100 Subject: [PATCH] Fix statistics tests. --- src/Manifolds.jl | 3 ++- src/manifolds/Euclidean.jl | 7 +++++-- test/statistics.jl | 4 ++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/Manifolds.jl b/src/Manifolds.jl index 6267506d0e..89c466e728 100644 --- a/src/Manifolds.jl +++ b/src/Manifolds.jl @@ -797,9 +797,10 @@ export ×, convert, complex_dot, decorated_manifold, - default_vector_transport_method, + default_approximation_method, default_inverse_retraction_method, default_retraction_method, + default_vector_transport_method, det_local_metric, differential_canonical_project, differential_canonical_project!, diff --git a/src/manifolds/Euclidean.jl b/src/manifolds/Euclidean.jl index 2abf194cc5..f5a5b3a69a 100644 --- a/src/manifolds/Euclidean.jl +++ b/src/manifolds/Euclidean.jl @@ -115,7 +115,10 @@ function check_vector(M::Euclidean{N,𝔽}, p, X; kwargs...) where {N,𝔽} end default_approximation_method(::Euclidean, ::typeof(mean)) = EfficientEstimator() -default_approximation_method(::Euclidean, ::typeof(median), ::Number) = EfficientEstimator() +function default_approximation_method(::Euclidean, ::typeof(median), ::Type{<:Number}) + return EfficientEstimator() +end + function default_approximation_method(::Euclidean, ::typeof(median), ::Array{T,0}) where {T} return EfficientEstimator() end @@ -610,7 +613,7 @@ function Statistics.median( end function Statistics.median( ::Union{Euclidean{TypeParameter{Tuple{}}},Euclidean{Tuple{}}}, - x::AbstractVector{<:Number}, + x::AbstractVector, w::AbstractWeights, ::EfficientEstimator; kwargs..., diff --git a/test/statistics.jl b/test/statistics.jl index b64f8a8758..09c50f9da3 100644 --- a/test/statistics.jl +++ b/test/statistics.jl @@ -133,7 +133,7 @@ function test_median( y = isnothing(method) ? median(M, x; kwargs...) : median(M, x, method; kwargs...) @test is_point(M, y; atol=10^-9) if yexp !== nothing - @test isapprox(M, y, yexp; atol=10^-5) + @test isapprox(M, y, yexp; atol=5 * 10^-5) end end @@ -797,7 +797,7 @@ end x = [normalize(randn(rng, 3)) for _ in 1:10] w = pweights([rand(rng) for _ in 1:length(x)]) m = normalize(median(Euclidean(3), x, w)) - mg = median(S, x, w, ExtrinsicEstimation(EfficientEstimator())) + mg = median(S, x, w, ExtrinsicEstimation(CyclicProximalPointEstimation())) @test isapprox(S, m, mg) end