From e4be25526cea106d38774a0664e72246c4c5c3ae Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Mon, 30 Dec 2024 16:00:22 +0100 Subject: [PATCH] Refine factory. --- src/plans/nonlinear_least_squares_plan.jl | 56 +++++++------------ .../test_nonlinear_least_squares_plan.jl | 13 ++++- 2 files changed, 30 insertions(+), 39 deletions(-) diff --git a/src/plans/nonlinear_least_squares_plan.jl b/src/plans/nonlinear_least_squares_plan.jl index 4cf927a92e..4df62d3537 100644 --- a/src/plans/nonlinear_least_squares_plan.jl +++ b/src/plans/nonlinear_least_squares_plan.jl @@ -150,24 +150,6 @@ function get_cost( ) end -function get_jacobian( - dmp::DefaultManoptProblem{mT,<:NonlinearLeastSquaresObjective}, p; kwargs... -) where {mT} - nlso = get_objective(dmp) - M = get_manifold(dmp) - J = zeros(length(nlso.objective), manifold_dimension(M)) - get_jacobian!(M, J, nlso, p; kwargs...) - return J -end -function get_jacobian!( - dmp::DefaultManoptProblem{mT,<:NonlinearLeastSquaresObjective}, J, p; kwargs... -) where {mT} - nlso = get_objective(dmp) - M = get_manifold(dmp) - get_jacobian!(M, J, nlso, p; kwargs...) - return J -end - function get_jacobian( M::AbstractManifold, nlso::NonlinearLeastSquaresObjective, p; kwargs... ) @@ -446,6 +428,7 @@ mutable struct LevenbergMarquardtState{ end end +function smoothing_factory end """ smoothing_factory(s::Symbol=:Identity) smoothing_factory((s,α)::Tuple{Union{Symbol, ManifoldHessianObjective,<:Real}) @@ -489,7 +472,7 @@ containing all smoothing functions with their repetitions mentioned Note that in the implementation the second derivative follows the general scheme of hessians and actually implements s''(x)[X] = s''(x)X``. """ -function smoothing_factory(s) end +smoothing_factory(s) smoothing_factory() = smoothing_factory(:Identity) smoothing_factory(o::ManifoldHessianObjective) = o @@ -497,6 +480,9 @@ smoothing_factory(o::VectorHessianFunction) = o function smoothing_factory(s::Symbol) return ManifoldHessianObjective(_smoothing_factory(Val(s))...) end +function _smoothing_factory(s::Symbol) + return _smoothing_factory(Val(s)) +end function smoothing_factory((s, α)::Tuple{Symbol,<:Real}) s, s_p, s_pp = _smoothing_factory(s, α) return ManifoldHessianObjective(s, s_p, s_pp) @@ -529,28 +515,22 @@ function smoothing_factory((o, k)::Tuple{ManifoldHessianObjective,<:Int}) hessian_type=ComponentVectorialType(), ) end -function smoothing_factory( - S::NTuple{ - n, - <:Union{ - Symbol, - ManifoldHessianObjective, - Tuple{Symbol,<:Int}, - Tuple{Symbol,<:Real}, - Tuple{ManifoldHessianObjective,<:Int}, - Tuple{ManifoldHessianObjective,<:Real}, - }, - } where {n}, -) +function smoothing_factory(S...) s = Function[] s_p = Function[] s_pp = Function[] # collect all functions including their copies into a large vector for t in S - _s, _s_p, _s_pp = _smoothing_factory(t...) - push!(s, _s...) - push!(s_p, _s_p...) - push!(s_pp, _s_pp...) + _s, _s_p, _s_pp = t isa Tuple ? _smoothing_factory(t...) : _smoothing_factory(t) + if _s isa Array + push!(s, _s...) + push!(s_p, _s_p...) + push!(s_pp, _s_pp...) + else + push!(s, _s) + push!(s_p, _s_p) + push!(s_pp, _s_pp) + end end k = length(s) return VectorHessianFunction( @@ -569,6 +549,10 @@ function _smoothing_factory(o::ManifoldHessianObjective) (E, x) -> get_gradient(E, o, x), (E, x, X) -> get_hessian(E, o, x, X) end +function _smoothing_factory(o::VectorHessianFunction) + return o.value!!, o.jacobian!!, o.hessians!! +end + function _smoothing_factory(o::ManifoldHessianObjective, α::Real) return (E, x) -> α^2 * get_cost(E, o, x / α^2), (E, x) -> get_gradient(E, o, x / α^2), diff --git a/test/plans/test_nonlinear_least_squares_plan.jl b/test/plans/test_nonlinear_least_squares_plan.jl index b6999c87d2..b317911740 100644 --- a/test/plans/test_nonlinear_least_squares_plan.jl +++ b/test/plans/test_nonlinear_least_squares_plan.jl @@ -129,8 +129,14 @@ using Manifolds, Manopt, Test end end @testset "Smootthing factory" begin - s1 = Manopt.smoothing_factory(:Identity) + s1 = Manopt.smoothing_factory() @test s1 isa ManifoldHessianObjective + s1s = Manopt.smoothing_factory((s1, 2.0)) + @test s1s isa ManifoldHessianObjective + s1v = Manopt.smoothing_factory((s1, 3)) + @test s1v isa VectorHessianFunction + @test length(s1v) == 3 + @test Manopt.smoothing_factory(s1) === s1 # Passthrough for mhos s2 = Manopt.smoothing_factory((:Identity, 2)) @test s2 isa VectorHessianFunction @@ -145,8 +151,9 @@ using Manifolds, Manopt, Test @test s4 isa ManifoldHessianObjective end - s5 = Manopt.smoothing_factory(((:Identity, 2), (:Huber, 3))) + # Combine all different types + s5 = Manopt.smoothing_factory((:Identity, 2), (:Huber, 3), s1, :Tukey, s2) @test s5 isa VectorHessianFunction - @test length(s5) == 5 + @test length(s5) == 9 end end