Skip to content

Commit

Permalink
Unify definition and usage of sub_kwargs in all 7 solvers with sub so…
Browse files Browse the repository at this point in the history
…lvers.
  • Loading branch information
kellertuer committed Dec 25, 2023
1 parent 2d829b2 commit 7d25f71
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 22 deletions.
5 changes: 3 additions & 2 deletions src/solvers/FrankWolfe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ the [`AbstractManifoldGradientObjective`](@ref) `gradient_objective` directly.
For points 2 and 3 the `sub_state` has to be set to the corresponding [`AbstractEvaluationType`](@ref), [`AllocatingEvaluation`](@ref) and [`InplaceEvaluation`](@ref), respectively
* `sub_state` - (`evaluation` if `sub_problem` is a function, a decorated [`GradientDescentState`](@ref) otherwise)
for a function, the evaluation is inherited from the Frank-Wolfe `evaluation` keyword.
* `sub_kwargs` - (`[]`) – keyword arguments to decorate the `sub_state` default state in case the sub_problem is not a function
* `sub_kwargs` - (`(;)`) – keyword arguments to decorate the `sub_state` default state in case the `sub_problem` is not a function
All other keyword arguments are passed to [`decorate_state!`](@ref) for decorators or
[`decorate_objective!`](@ref), respectively.
Expand Down Expand Up @@ -259,7 +259,7 @@ function Frank_Wolfe_method!(
StopWhenChangeLess(1.0e-8),
sub_cost=FrankWolfeCost(p, initial_vector),
sub_grad=FrankWolfeGradient(p, initial_vector),
sub_kwargs=[],
sub_kwargs=(;),
sub_objective=ManifoldGradientObjective(sub_cost, sub_grad),
sub_problem=DefaultManoptProblem(
M,
Expand All @@ -278,6 +278,7 @@ function Frank_Wolfe_method!(
stepsize=default_stepsize(
M, GradientDescentState; retraction_method=retraction_method
),
sub_kwargs...,
);
objective_type=objective_type,
sub_kwargs...,
Expand Down
22 changes: 14 additions & 8 deletions src/solvers/adaptive_regularization_with_cubics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,15 +391,19 @@ function adaptive_regularization_with_cubics!(
γ1::R=0.1,
γ2::R=2.0,
θ::R=0.5,
sub_kwargs=[],
sub_kwargs=(;),
sub_stopping_criterion::StoppingCriterion=StopAfterIteration(maxIterLanczos) |
StopWhenFirstOrderProgress(θ),
sub_state::Union{<:AbstractManoptSolverState,<:AbstractEvaluationType}=LanczosState(
TangentSpace(M, copy(M, p));
maxIterLanczos=maxIterLanczos,
σ=σ,
θ=θ,
stopping_criterion=sub_stopping_criterion,
sub_state::Union{<:AbstractManoptSolverState,<:AbstractEvaluationType}=decorate_state!(
LanczosState(
TangentSpace(M, copy(M, p));
maxIterLanczos=maxIterLanczos,
σ=σ,
θ=θ,
stopping_criterion=sub_stopping_criterion,
sub_kwargs...,
);
sub_kwargs,
),
sub_objective=nothing,
sub_problem=nothing,
Expand All @@ -414,7 +418,9 @@ function adaptive_regularization_with_cubics!(
) where {T,R,O<:Union{ManifoldHessianObjective,AbstractDecoratedManifoldObjective}}
dmho = decorate_objective!(M, mho; objective_type=objective_type, kwargs...)
if isnothing(sub_objective)
sub_objective = AdaptiveRagularizationWithCubicsModelObjective(dmho, σ)
sub_objective = decorate_objective!(
M, AdaptiveRagularizationWithCubicsModelObjective(dmho, σ); sub_kwargs...
)
end
if isnothing(sub_problem)
sub_problem = DefaultManoptProblem(TangentSpace(M, copy(M, p)), sub_objective)
Expand Down
2 changes: 1 addition & 1 deletion src/solvers/augmented_Lagrangian_method.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ function augmented_Lagrangian_method!(
objective_type=:Riemannian,
sub_cost=AugmentedLagrangianCost(cmo, ρ, μ, λ),
sub_grad=AugmentedLagrangianGrad(cmo, ρ, μ, λ),
sub_kwargs=[],
sub_kwargs=(;),
sub_stopping_criterion=StopAfterIteration(300) |
StopWhenGradientNormLess(ϵ) |
StopWhenStepsizeLess(1e-8),
Expand Down
8 changes: 4 additions & 4 deletions src/solvers/difference-of-convex-proximal-point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ difference_of_convex_proximal_point(M, grad_h, p0; g=g, grad_g=grad_g)
This is generated by default when `grad_g` is provided. You can specify your own by overwriting this keyword.
* `sub_hess` – (a finite difference approximation by default) specify
a Hessian of the subproblem, which the default solver, see `sub_state` needs
* `sub_kwargs` – (`[]`) pass keyword arguments to the `sub_state`, in form of
* `sub_kwargs` – (`(;)`) pass keyword arguments to the `sub_state`, in form of
a named tuple, e.g. `(; kwname=value)`, unless you set the `sub_state` directly.
* `sub_objective` – (a gradient or hessian objective based on the last 3 keywords)
provide the objective used within `sub_problem` (if that is not specified by the user)
Expand Down Expand Up @@ -331,7 +331,7 @@ function difference_of_convex_proximal_point!(
ProximalDCGrad(grad_g, copy(M, p), λ(1); evaluation=evaluation)
end,
sub_hess=ApproxHessianFiniteDifference(M, copy(M, p), sub_grad; evaluation=evaluation),
sub_kwargs=[],
sub_kwargs=(;),
sub_stopping_criterion=StopAfterIteration(300) | StopWhenGradientNormLess(1e-8),
sub_objective=if isnothing(sub_cost) || isnothing(sub_grad)
nothing
Expand Down Expand Up @@ -368,7 +368,7 @@ function difference_of_convex_proximal_point!(
decorate_state!(
if isnothing(sub_hess)
GradientDescentState(
M, copy(M, p); stopping_criterion=sub_stopping_criterion
M, copy(M, p); stopping_criterion=sub_stopping_criterion, sub_kwargs...
)
else
TrustRegionsState(
Expand All @@ -378,7 +378,7 @@ function difference_of_convex_proximal_point!(
TangentSpace(M, copy(M, p)),
TrustRegionModelObjective(sub_objective),
),
TruncatedConjugateGradientState(TangentSpace(M, p)),
TruncatedConjugateGradientState(TangentSpace(M, p); sub_kwargs...),
)
end;
sub_kwargs...,
Expand Down
12 changes: 8 additions & 4 deletions src/solvers/difference_of_convex_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ difference_of_convex_algorithm(M, f, g, grad_h, p; grad_g=grad_g)
This is generated by default when `grad_g` is provided. You can specify your own by overwriting this keyword.
* `sub_hess` – (a finite difference approximation by default) specify a Hessian
of the subproblem, which the default solver, see `sub_state` needs
* `sub_kwargs` - (`[]`) pass keyword arguments to the `sub_state`, in form of
* `sub_kwargs` - (`(;)`) pass keyword arguments to the `sub_state`, in form of
a named tuple, e.g. `(; kwname=value)`, unless you set the `sub_state` directly.
* `sub_objective` - (a gradient or hessian objective based on the last 3 keywords)
provide the objective used within `sub_problem` (if that is not specified by the user)
Expand Down Expand Up @@ -301,7 +301,7 @@ function difference_of_convex_algorithm!(
)
end,
sub_hess=ApproxHessianFiniteDifference(M, copy(M, p), sub_grad; evaluation=evaluation),
sub_kwargs=[],
sub_kwargs=(;),
sub_stopping_criterion=StopAfterIteration(300) | StopWhenGradientNormLess(1e-8),
sub_objective=if isnothing(sub_cost) || isnothing(sub_grad)
nothing
Expand Down Expand Up @@ -333,11 +333,15 @@ function difference_of_convex_algorithm!(
decorate_state!(
if isnothing(sub_hess)
GradientDescentState(
M, copy(M, p); stopping_criterion=sub_stopping_criterion
M, copy(M, p); stopping_criterion=sub_stopping_criterion, sub_kwargs...
)
else
TrustRegionsState(
M, copy(M, p), sub_objective; stopping_criterion=sub_stopping_criterion
M,
copy(M, p),
sub_objective;
stopping_criterion=sub_stopping_criterion,
sub_kwargs...,
)
end;
sub_kwargs...,
Expand Down
3 changes: 2 additions & 1 deletion src/solvers/exact_penalty_method.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ function exact_penalty_method!(
smoothing=LogarithmicSumOfExponentials(),
sub_cost=ExactPenaltyCost(cmo, ρ, u; smoothing=smoothing),
sub_grad=ExactPenaltyGrad(cmo, ρ, u; smoothing=smoothing),
sub_kwargs=[],
sub_kwargs=(;),
sub_problem::AbstractManoptProblem=DefaultManoptProblem(
M,
decorate_objective!(
Expand All @@ -324,6 +324,7 @@ function exact_penalty_method!(
),
stopping_criterion=sub_stopping_criterion,
stepsize=default_stepsize(M, QuasiNewtonState),
sub_kwargs...,
);
sub_kwargs...,
),
Expand Down
4 changes: 2 additions & 2 deletions src/solvers/trust_regions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,8 @@ function trust_regions!(
reduction_factor::R=0.25,
augmentation_threshold::R=0.75,
augmentation_factor::R=2.0,
sub_kwargs=[],
sub_objective=TrustRegionModelObjective(mho),
sub_kwargs=(;),
sub_objective=decorate_objective!(M, TrustRegionModelObjective(mho), sub_kwargs...),
sub_problem=DefaultManoptProblem(TangentSpace(M, p), sub_objective),
sub_stopping_criterion::StoppingCriterion=StopAfterIteration(manifold_dimension(M)) |
StopWhenResidualIsReducedByFactorOrPower(;
Expand Down

0 comments on commit 7d25f71

Please sign in to comment.