Skip to content

Commit

Permalink
Continue sketching the TR subsolver
Browse files Browse the repository at this point in the history
but there is still some work to do, to make the new state constructors nice.
  • Loading branch information
kellertuer committed Oct 25, 2023
1 parent bc2b732 commit da87f25
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 37 deletions.
7 changes: 7 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ All notable Changes to the Julia package `Manopt.jl` will be documented in this
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.4.41] - dd/mm/yyyy

### Changed

`trust_regions` is now more flexible and the sub solver (Steinhaug-Toint tCG by default)
can now be exchanged.

## [0.4.40] – 24/10/2023

### Added
Expand Down
6 changes: 6 additions & 0 deletions src/solvers/difference-of-convex-proximal-point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ mutable struct DifferenceOfConvexProximalState{
)
end
end
# no point -> add point
function DifferenceOfConvexProximalState(
M::AbstractManifold, sub_problem, sub_state; kwargs...
)
return DifferenceOfConvexProximalState(M, rand(M), sub_problem, sub_state; kwargs...)
end
get_iterate(dcps::DifferenceOfConvexProximalState) = dcps.p
function set_iterate!(dcps::DifferenceOfConvexProximalState, M, p)
copyto!(M, dcps.p, p)
Expand Down
1 change: 1 addition & 0 deletions src/solvers/difference_of_convex_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ function difference_of_convex_algorithm!(
M, copy(M, p); stopping_criterion=sub_stopping_criterion
)
else
# TODO Fix constructor
TrustRegionsState(M, copy(M, p); stopping_criterion=sub_stopping_criterion)
end;
sub_kwargs...,
Expand Down
88 changes: 51 additions & 37 deletions src/solvers/trust_regions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ All the following fields (besides `p`) can be set by specifying them as keywords
used. This option is set to true in some scenarios to escape saddle
points, but is otherwise seldom activated.
* `ρ_regularization` – (`10000.0`) regularize the model fitness ``ρ`` to avoid division by zero
* `sub_problem` – an [`AbstractManoptProblem`](@ref) problem or a function `(M, p, X) -> q` or `(M, q, p, X)` for the a closed form solution of the sub problem
* `sub_state` – ([`TruncatedConjugateGradientState`](@ref)`(M, p, X)`)
* `σ` – (`0.0` or `1e-6` depending on `randomize`) Gaussian standard deviation when creating the random initial tangent vector
* `trust_region_radius` – (`max_trust_region_radius / 8`) the (initial) trust-region radius
Expand All @@ -46,13 +47,7 @@ keyword arguments.
[`trust_regions`](@ref)
"""
mutable struct TrustRegionsState{
P,
T,
SC<:StoppingCriterion,
RTR<:AbstractRetractionMethod,
R<:Real,
Proj,
Op<:AbstractHessianSolverState,
P,T,Pr,St,SC<:StoppingCriterion,RTR<:AbstractRetractionMethod,R<:Real,Proj
} <: AbstractHessianSolverState
p::P
X::T
Expand All @@ -64,7 +59,8 @@ mutable struct TrustRegionsState{
project!::Proj
acceptance_rate::R
ρ_regularization::R
sub_state::Op
sub_problem::Pr
sub_state::St
p_proposal::P
f_proposal::R
# Only required for Random mode Random
Expand All @@ -79,7 +75,7 @@ mutable struct TrustRegionsState{
reduction_factor::R
augmentation_threshold::R
augmentation_factor::R
function TrustRegionsState{P,T,SC,RTR,R,Proj,Op}(
function TrustRegionsState{P,T,Pr,St,SC,RTR,R,Proj}(
p::P,
X::T,
trust_region_radius::R,
Expand All @@ -91,21 +87,14 @@ mutable struct TrustRegionsState{
retraction_method::RTR,
reduction_threshold::R,
augmentation_threshold::R,
sub_state::Op,
sub_problem::Pr,
sub_state::St,
project!::Proj=copyto!,
reduction_factor=0.25,
augmentation_factor=2.0,
σ::R=random ? 1e-6 : 0.0,
) where {
P,
T,
SC<:StoppingCriterion,
RTR<:AbstractRetractionMethod,
R<:Real,
Proj,
Op<:AbstractHessianSolverState,
}
trs = new{P,T,SC,RTR,R,Proj,Op}()
) where {P,T,Pr,St,SC<:StoppingCriterion,RTR<:AbstractRetractionMethod,R<:Real,Proj}
trs = new{P,T,Pr,St,SC,RTR,R,Proj}()
trs.p = p
trs.X = X
trs.stop = stopping_citerion
Expand All @@ -115,6 +104,7 @@ mutable struct TrustRegionsState{
trs.acceptance_rate = acceptance_rate
trs.ρ_regularization = ρ_regularization
trs.randomize = randomize
trs.sub_problem = sub_problem
trs.sub_state = sub_state
trs.reduction_threshold = reduction_threshold
trs.reduction_factor = reduction_factor
Expand All @@ -125,11 +115,25 @@ mutable struct TrustRegionsState{
return trs
end
end
# No point no state -> add point
function TrustRegionsState(
M, sub_problem::Pr; kwargs...
) where {Pr<:Union{AbstractManoptProblem,<:Function}}
return TrustRegionsState(M, rand(M), sub_problem; kwargs...)
end
# No point but state -> add point
function TrustRegionsState(
M, sub_problem::Pr, sub_state::St; kwargs...
) where {Pr<:Union{AbstractManoptProblem,<:Function},St}
return TrustRegionsState(M, rand(M), sub_problem, sub_state; kwargs...)
end
# HessianObjective (problem) provided -> Constructor like in ALM?
function TrustRegionsState(
M::TM,
p::P=rand(M);
p::P,
sub_problem::Pr,
sub_state::St=TruncatedConjugateGradientState(M, p, X);
X::T=zero_vector(M, p),
sub_state::Op=TruncatedConjugateGradientState(M, p, X),
ρ_prime::R=0.1, #deprecated, remove on next breaking change
acceptance_rate=ρ_prime,
ρ_regularization::R=1000.0,
Expand All @@ -146,15 +150,16 @@ function TrustRegionsState(
σ=randomize ? 1e-4 : 0.0,
) where {
TM<:AbstractManifold,
Pr,
St,
P,
T,
R<:Real,
SC<:StoppingCriterion,
RTR<:AbstractRetractionMethod,
Proj,
Op<:AbstractHessianSolverState,
}
return TrustRegionsState{P,T,SC,RTR,R,Proj,Op}(
return TrustRegionsState{P,T,Pr,St,SC,RTR,R,Proj}(
p,
X,
trust_region_radius,
Expand All @@ -166,13 +171,17 @@ function TrustRegionsState(
retraction_method,
reduction_threshold,
augmentation_threshold,
sub_problem,
sub_state,
project!,
reduction_factor,
augmentation_factor,
σ,
)
end
# TODO Given the HessianObjective of the main task -> generate the sub_problem and state
# (default ones, similar to ALM, Check with the other subsolvers that they have the same)

get_iterate(trs::TrustRegionsState) = trs.p
function set_iterate!(trs::TrustRegionsState, M, p)
copyto!(M, trs.p, p)
Expand Down Expand Up @@ -253,8 +262,14 @@ function TrustRegionTangentSpaceModelObjective(
trust_region_radius::R=injectivity_radius(M) / 8,
gradient::T=get_gradient(M, mho, p),
bilinear_form::TH=nothing,
) where {TH<:Union{Function,Nothing},O<:AbstractManifoldHessianObjective,T,R}
return TrustRegionTangentSpaceModelObjective{TH,O,T,R}(
) where {
TH<:Union{Function,Nothing},
E<:AbstractEvaluationType,
O<:AbstractManifoldHessianObjective{E},
T,
R,
}
return TrustRegionTangentSpaceModelObjective{E,TH,O,T,R}(
mho, cost, gradient, bilinear_form, trust_region_radius
)
end
Expand Down Expand Up @@ -488,14 +503,11 @@ function trust_regions!(
reduction_factor::R=0.25,
augmentation_threshold::R=0.75,
augmentation_factor::R=2.0,
# ToDo – Tangent Space in Base? Implement TR Model, otherwise like below
sub_problem=DefaultManoptProblem(
TangentSpace(M, p),
TrustRegionTangentSpaceModelObjective(
M, mho, p; trust_region_radius=trust_region_radius
),
)TangentSpaceModelProblem(M, p, TrustRegionModel(mho)),
sub_state::AbstractHessianSolverState=TruncatedConjugateGradientState(
sub_objective=TrustRegionTangentSpaceModelObjective(
M, mho, p; trust_region_radius=trust_region_radius
),
sub_problem=DefaultManoptProblem(TangentSpace(M, p), sub_objective),
sub_state::Union{AbstractHessianSolverState,AbstractEvaluationType}=TruncatedConjugateGradientState(
M,
p,
zero_vector(M, p);
Expand All @@ -521,9 +533,10 @@ function trust_regions!(
dmp = DefaultManoptProblem(M, dmho)
trs = TrustRegionsState(
M,
p;
p,
sub_problem,
sub_state;
X=get_gradient(dmp, p),
sub_state=sub_state,
trust_region_radius=trust_region_radius,
max_trust_region_radius=max_trust_region_radius,
acceptance_rate=acceptance_rate,
Expand Down Expand Up @@ -578,9 +591,10 @@ function step_solver!(mp::AbstractManoptProblem, trs::TrustRegionsState, i)
# TODO provide these setters for the sub problem / sub state
# set_paramater!(trs.sub_problem, :Basepoint, trs.p)
set_manopt_parameter!(trs.sub_state, :Basepoint, trs.p)
set_manopt_parameter!(trs.sub_problem, :Basepoint, trs.p)
set_manopt_parameter!(trs.sub_state, :Iterate, trs.Y)
set_manopt_parameter!(trs.sub_state, :TrustRegionRadius, trs.trust_region_radius)
solve!(mp, trs.sub_state)
solve!(trs.sub_problem, trs.sub_state)
#
copyto!(M, trs.Y, trs.p, get_solver_result(trs.sub_state))
f = get_cost(mp, trs.p)
Expand Down

0 comments on commit da87f25

Please sign in to comment.