Skip to content

Commit

Permalink
Refine factory.
Browse files Browse the repository at this point in the history
  • Loading branch information
kellertuer committed Dec 30, 2024
1 parent 2002ef5 commit e4be255
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 39 deletions.
56 changes: 20 additions & 36 deletions src/plans/nonlinear_least_squares_plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,24 +150,6 @@ function get_cost(
)
end

function get_jacobian(
dmp::DefaultManoptProblem{mT,<:NonlinearLeastSquaresObjective}, p; kwargs...
) where {mT}
nlso = get_objective(dmp)
M = get_manifold(dmp)
J = zeros(length(nlso.objective), manifold_dimension(M))
get_jacobian!(M, J, nlso, p; kwargs...)
return J
end
function get_jacobian!(
dmp::DefaultManoptProblem{mT,<:NonlinearLeastSquaresObjective}, J, p; kwargs...
) where {mT}
nlso = get_objective(dmp)
M = get_manifold(dmp)
get_jacobian!(M, J, nlso, p; kwargs...)
return J
end

function get_jacobian(
M::AbstractManifold, nlso::NonlinearLeastSquaresObjective, p; kwargs...
)
Expand Down Expand Up @@ -446,6 +428,7 @@ mutable struct LevenbergMarquardtState{
end
end

function smoothing_factory end
"""
smoothing_factory(s::Symbol=:Identity)
smoothing_factory((s,α)::Tuple{Union{Symbol, ManifoldHessianObjective,<:Real})
Expand Down Expand Up @@ -489,14 +472,17 @@ containing all smoothing functions with their repetitions mentioned
Note that in the implementation the second derivative follows the general scheme of hessians
and actually implements s''(x)[X] = s''(x)X``.
"""
function smoothing_factory(s) end
smoothing_factory(s)

smoothing_factory() = smoothing_factory(:Identity)
smoothing_factory(o::ManifoldHessianObjective) = o
smoothing_factory(o::VectorHessianFunction) = o
function smoothing_factory(s::Symbol)
return ManifoldHessianObjective(_smoothing_factory(Val(s))...)
end
function _smoothing_factory(s::Symbol)
return _smoothing_factory(Val(s))
end
function smoothing_factory((s, α)::Tuple{Symbol,<:Real})
s, s_p, s_pp = _smoothing_factory(s, α)
return ManifoldHessianObjective(s, s_p, s_pp)
Expand Down Expand Up @@ -529,28 +515,22 @@ function smoothing_factory((o, k)::Tuple{ManifoldHessianObjective,<:Int})
hessian_type=ComponentVectorialType(),
)
end
function smoothing_factory(
S::NTuple{
n,
<:Union{
Symbol,
ManifoldHessianObjective,
Tuple{Symbol,<:Int},
Tuple{Symbol,<:Real},
Tuple{ManifoldHessianObjective,<:Int},
Tuple{ManifoldHessianObjective,<:Real},
},
} where {n},
)
function smoothing_factory(S...)
s = Function[]
s_p = Function[]
s_pp = Function[]
# collect all functions including their copies into a large vector
for t in S
_s, _s_p, _s_pp = _smoothing_factory(t...)
push!(s, _s...)
push!(s_p, _s_p...)
push!(s_pp, _s_pp...)
_s, _s_p, _s_pp = t isa Tuple ? _smoothing_factory(t...) : _smoothing_factory(t)
if _s isa Array
push!(s, _s...)
push!(s_p, _s_p...)
push!(s_pp, _s_pp...)
else
push!(s, _s)
push!(s_p, _s_p)
push!(s_pp, _s_pp)
end
end
k = length(s)
return VectorHessianFunction(
Expand All @@ -569,6 +549,10 @@ function _smoothing_factory(o::ManifoldHessianObjective)
(E, x) -> get_gradient(E, o, x),
(E, x, X) -> get_hessian(E, o, x, X)
end
function _smoothing_factory(o::VectorHessianFunction)
return o.value!!, o.jacobian!!, o.hessians!!
end

function _smoothing_factory(o::ManifoldHessianObjective, α::Real)
return (E, x) -> α^2 * get_cost(E, o, x / α^2),
(E, x) -> get_gradient(E, o, x / α^2),
Expand Down
13 changes: 10 additions & 3 deletions test/plans/test_nonlinear_least_squares_plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,14 @@ using Manifolds, Manopt, Test
end
end
@testset "Smootthing factory" begin
s1 = Manopt.smoothing_factory(:Identity)
s1 = Manopt.smoothing_factory()
@test s1 isa ManifoldHessianObjective
s1s = Manopt.smoothing_factory((s1, 2.0))
@test s1s isa ManifoldHessianObjective
s1v = Manopt.smoothing_factory((s1, 3))
@test s1v isa VectorHessianFunction
@test length(s1v) == 3

@test Manopt.smoothing_factory(s1) === s1 # Passthrough for mhos
s2 = Manopt.smoothing_factory((:Identity, 2))
@test s2 isa VectorHessianFunction
Expand All @@ -145,8 +151,9 @@ using Manifolds, Manopt, Test
@test s4 isa ManifoldHessianObjective
end

s5 = Manopt.smoothing_factory(((:Identity, 2), (:Huber, 3)))
# Combine all different types
s5 = Manopt.smoothing_factory((:Identity, 2), (:Huber, 3), s1, :Tukey, s2)
@test s5 isa VectorHessianFunction
@test length(s5) == 5
@test length(s5) == 9
end
end

0 comments on commit e4be255

Please sign in to comment.