Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduced an entry change stopping criterion and generalise the gradient norm SC. #345

Merged
merged 12 commits into from
Jan 18, 2024
Merged
11 changes: 10 additions & 1 deletion Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@ All notable Changes to the Julia package `Manopt.jl` will be documented in this
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.4.48]
## [0.4.49] January 17, 2024

### Added

* A `StopWhenEntryChangeLess` to be able to stop on arbitrary small changes of specific fields
* generalises `StopWhenGradientNormLess` to accept arbitrary `norm=` functions
* refactor the default in `particle_swarm` to no longer “misuse” the iteration change check,
but actually the new one one the `:swarm` entry

## [0.4.48] January 16, 2024

### Fixed

Expand Down
1 change: 1 addition & 0 deletions src/Manopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ export StopAfter,
StopWhenChangeLess,
StopWhenCostLess,
StopWhenCurvatureIsNegative,
StopWhenEntryChangeLess,
StopWhenGradientChangeLess,
StopWhenGradientNormLess,
StopWhenFirstOrderProgress,
Expand Down
126 changes: 115 additions & 11 deletions src/plans/stopping_criterion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,91 @@ function update_stopping_criterion!(c::StopWhenCostLess, ::Val{:MinCost}, v)
return c
end

@doc raw"""
StopWhenEntryChangeLess

Evaluate whether a certain fields change is less than a certain threshold

## Fields

* `field` – a symbol adressing the corresponding field in a certain subtype of [`AbstractManoptSolverState`](@ref)
to track
* `distance` – a function `(problem, state, v1, v2) -> R` that computes the distance between two possible values of the `field`
* `storage` – a [`StoreStateAction`](@ref) to store the previous value of the `field`
* `threshold` – the threshold to indicate to stop when the distance is below this value

# Internal fields

* `reason` – store a string reason when the stop was indicated
* `at_iteration` – store the iteration at which the stop indication happened

stores a threshold when to stop looking at the norm of the change of the
optimization variable from within a [`AbstractManoptSolverState`](@ref), i.e `get_iterate(o)`.
For the storage a [`StoreStateAction`](@ref) is used

# Constructor

StopWhenEntryChangeLess(
field::Symbol
distance,
threshold;
storage::StoreStateAction=StoreStateAction([field]),
)

"""
mutable struct StopWhenEntryChangeLess{F,TF,TSSA<:StoreStateAction} <: StoppingCriterion
at_iteration::Int
distance::F
field::Symbol
reason::String
storage::TSSA
threshold::TF
end
function StopWhenEntryChangeLess(
field::Symbol, distance::F, threshold::TF; storage::TSSA=StoreStateAction([field])
) where {F,TF,TSSA<:StoreStateAction}
return StopWhenEntryChangeLess{F,TF,TSSA}(0, distance, field, "", storage, threshold)
end

function (sc::StopWhenEntryChangeLess)(
mp::AbstractManoptProblem, s::AbstractManoptSolverState, i
)
if i == 0 # reset on init
sc.reason = ""
sc.at_iteration = 0
end
if has_storage(sc.storage, sc.field)
old_field_value = get_storage(sc.storage, sc.field)
ε = sc.distance(mp, s, old_field_value, getproperty(s, sc.field))
if (i > 0) && (ε < sc.threshold)
sc.reason = "The algorithm performed a step with a change ($ε) in $(sc.field) less than $(sc.threshold).\n"
sc.at_iteration = i
sc.storage(mp, s, i)
return true
end
end
sc.storage(mp, s, i)
return false
end
function status_summary(sc::StopWhenEntryChangeLess)
has_stopped = length(sc.reason) > 0
s = has_stopped ? "reached" : "not reached"
return "|Δ:$(sc.field)| < $(sc.threshold): $s"
end

"""
update_stopping_criterion!(c::StopWhenEntryChangeLess, :Threshold, v)

Update the minimal cost below which the algorithm shall stop
"""
function update_stopping_criterion!(c::StopWhenEntryChangeLess, ::Val{:Threshold}, v)
c.threshold = v
return c
end
function show(io::IO, c::StopWhenEntryChangeLess)
return print(io, "StopWhenEntryChangeLess\n $(status_summary(c))")
end

@doc raw"""
StopWhenGradientChangeLess <: StoppingCriterion

Expand Down Expand Up @@ -419,31 +504,50 @@ end

A stopping criterion based on the current gradient norm.

# Fields

* `norm` – a function `(M::AbstractManifold, p, X) -> ℝ` that computes a norm of the gradient `X` in the tangent space at `p` on `M``
* `threshold` – the threshold to indicate to stop when the distance is below this value

# Internal fields

* `reason` – store a string reason when the stop was indicated
* `at_iteration` – store the iteration at which the stop indication happened


# Constructor

StopWhenGradientNormLess(ε::Float64)
StopWhenGradientNormLess(ε; norm=(M,p,X) -> norm(M,p,X))

Create a stopping criterion with threshold `ε` for the gradient, that is, this criterion
indicates to stop when [`get_gradient`](@ref) returns a gradient vector of norm less than `ε`.
indicates to stop when [`get_gradient`](@ref) returns a gradient vector of norm less than `ε`,
where the norm to use can be specified in the `norm=` keyword.
"""
mutable struct StopWhenGradientNormLess <: StoppingCriterion
mutable struct StopWhenGradientNormLess{F,TF} <: StoppingCriterion
norm::F
threshold::Float64
reason::String
at_iteration::Int
StopWhenGradientNormLess(ε::Float64) = new(ε, "", 0)
function StopWhenGradientNormLess(ε::TF; norm::F=norm) where {F,TF}
return new{F,TF}(norm, ε, "", 0)
end
end
function (c::StopWhenGradientNormLess)(

function (sc::StopWhenGradientNormLess)(
mp::AbstractManoptProblem, s::AbstractManoptSolverState, i::Int
)
M = get_manifold(mp)
if i == 0 # reset on init
c.reason = ""
c.at_iteration = 0
sc.reason = ""
sc.at_iteration = 0
end
if (norm(M, get_iterate(s), get_gradient(s)) < c.threshold) && (i > 0)
c.reason = "The algorithm reached approximately critical point after $i iterations; the gradient norm ($(norm(M,get_iterate(s),get_gradient(s)))) is less than $(c.threshold).\n"
c.at_iteration = i
return true
if (i > 0)
grad_norm = sc.norm(M, get_iterate(s), get_gradient(s))
if grad_norm < sc.threshold
sc.reason = "The algorithm reached approximately critical point after $i iterations; the gradient norm ($(grad_norm)) is less than $(sc.threshold).\n"
sc.at_iteration = i
return true
end
end
return false
end
Expand Down
33 changes: 9 additions & 24 deletions src/solvers/particle_swarm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,15 @@ function particle_swarm!(
social_weight::Real=1.4,
cognitive_weight::Real=1.4,
stopping_criterion::StoppingCriterion=StopAfterIteration(500) |
StopWhenChangeLess(1e-4),
StopWhenEntryChangeLess(
:swarm,
(p, st, old_swarm, swarm) -> distance(
PowerManifold(get_manifold(p), NestedPowerRepresentation(), length(swarm)),
old_swarm,
swarm,
),
1e-4,
),
kellertuer marked this conversation as resolved.
Show resolved Hide resolved
retraction_method::AbstractRetractionMethod=default_retraction_method(M, eltype(swarm)),
inverse_retraction_method::AbstractInverseRetractionMethod=default_inverse_retraction_method(
M, eltype(swarm)
Expand Down Expand Up @@ -366,26 +374,3 @@ function step_solver!(mp::AbstractManoptProblem, s::ParticleSwarmState, ::Any)
end
end
end
#
# Change not only refers to different iterates (best visited) but to whole `swarm`
# but also lives in the power manifold on M, so we have to adapt StopWhenChangeless
#
function (c::StopWhenChangeLess)(mp::AbstractManoptProblem, s::ParticleSwarmState, i)
if has_storage(c.storage, :Population)
swarm_old = get_storage(c.storage, :Population)
n = length(s.swarm)
d = distance(
PowerManifold(get_manifold(mp), NestedPowerRepresentation(), n),
s.swarm,
swarm_old,
)
if d < c.threshold && i > 0
c.reason = "The algorithm performed a step with a change ($d in the population) less than $(c.threshold).\n"
c.at_iteration = i
c.storage(mp, s, i)
return true
end
end
c.storage(mp, s, i)
return false
end
13 changes: 10 additions & 3 deletions test/plans/test_stopping_criteria.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ end
@testset "Test StopAfter" begin
p = TestStopProblem()
o = TestStopState()
s = StopAfter(Second(1))
s = StopAfter(Millisecond(30))
@test !Manopt.indicates_convergence(s)
@test Manopt.status_summary(s) == "stopped after $(s.threshold):\tnot reached"
@test repr(s) == "StopAfter(Second(1))\n $(Manopt.status_summary(s))"
@test repr(s) == "StopAfter(Millisecond(30))\n $(Manopt.status_summary(s))"
s(p, o, 0) # Start
@test s(p, o, 1) == false
sleep(1.02)
sleep(0.05)
@test s(p, o, 2) == true
@test_throws ErrorException StopAfter(Second(-1))
@test_throws ErrorException update_stopping_criterion!(s, :MaxTime, Second(-1))
Expand Down Expand Up @@ -170,3 +170,10 @@ end
update_stopping_criterion!(s1, :MinStepsize, 1e-1)
@test s1.threshold == 1e-1
end

@testset "Test further setters" begin
swecl = StopWhenEntryChangeLess(:dummy, (p, s, v, w) -> norm(w - v), 1e-5)
@test startswith(repr(swecl), "StopWhenEntryChangeLess\n")
update_stopping_criterion!(swecl, :Threshold, 1e-1)
@test swecl.threshold == 1e-1
end