From a5b21a74985b083d9cce94e0864e6e496132884e Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Tue, 16 Jan 2024 18:18:23 +0100 Subject: [PATCH 01/11] INtroduce a stopping criterion to track the change of arbitrary fields. Check that we change the PSO stopping criterion accordingly (and not overwrite iterate change) --- src/plans/stopping_criterion.jl | 82 +++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/src/plans/stopping_criterion.jl b/src/plans/stopping_criterion.jl index a254a3f71d..1cdb945794 100644 --- a/src/plans/stopping_criterion.jl +++ b/src/plans/stopping_criterion.jl @@ -316,6 +316,88 @@ function update_stopping_criterion!(c::StopWhenCostLess, ::Val{:MinCost}, v) return c end +@doc raw""" + StopWhenEntryChangeLess + +Evaluate whether a certain fields change is less than a certain threshold + +## Fields + +* `field` – a symbol adressing the corresponding field in a certain subtype of [`AbstractManoptSolverState`](@ref) + to track +* `distance` – a function `(problem, state, v1, v2) -> R` that computes the distance between two possible values of the `field` +* `storage` – a [`StoreStateAction`](@ref) to store the previous value of the `field` +* `threshold` – the threshold to indicate to stop when the distance is below this value + +# Internal fields + +* `reason` – store a string reason when the stop was indicated +* `at_iteration` – store the iteration at which the stop indication happened + +stores a threshold when to stop looking at the norm of the change of the +optimization variable from within a [`AbstractManoptSolverState`](@ref), i.e `get_iterate(o)`. +For the storage a [`StoreStateAction`](@ref) is used + +# Constructor + + StopWhenEntryChangeLess( + field::Symbol + distance, + threshold; + storage::StoreStateAction=StoreStateAction([field]), + ) + +""" +mutable struct StopWhenEntryChangeLess{F,TI,TF,TSSA<:StoreStateAction} <: StoppingCriterion + at_iteration::Int + distance::F + field::Symbol + reason::String + storage::TSSA + threshold::TF +end +function StopWhenEntryChangeLess( + field::Symbol, distance::F, threshold::TF; storage::TSSA=StoreStateAction([field]) +) where {F,TF,TSSA<:StoreStateAction} + return StopWhenEntryChangeLess{F,TF,TSSA}(0, distance, field, "", storage, threshold) +end + +function (sc::StopWhenEntryChangeLess)( + mp::AbstractManoptProblem, s::AbstractManoptSolverState, i +) + if i == 0 # reset on init + sc.reason = "" + sc.at_iteration = 0 + end + if has_storage(sc.storage, sc.field) + old_field_value = get_storage(sc.storage, sc.field) + ε = sc.distance(mp, s, old_field_value, getproperty(st, d.field)) + if (i > 0) && (ε < sc.threshold) + sc.reason = "The algorithm performed a step with a change ($ε) in $(sc.field) less than $(sc.threshold).\n" + sc.at_iteration = i + sc.storage(mp, s, i) + return true + end + end + sc.storage(mp, s, i) + return false +end +function status_summary(c::StopWhenEntryChangeLess) + has_stopped = length(c.reason) > 0 + s = has_stopped ? "reached" : "not reached" + return "|Δ:$(field)| < $(c.threshold): $s" +end + +""" + update_stopping_criterion!(c::StopWhenEntryChangeLess, :Threshold, v) + +Update the minimal cost below which the algorithm shall stop +""" +function update_stopping_criterion!(c::StopWhenEntryChangeLess, ::Val{:Threshold}, v) + c.threshold = v + return c +end + @doc raw""" StopWhenGradientChangeLess <: StoppingCriterion From 634511b452192940cc9f458563d2ebc247a8c6df Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Tue, 16 Jan 2024 18:18:47 +0100 Subject: [PATCH 02/11] generalize the gradient change stopping criterion to accept arbitrary norms. --- src/plans/stopping_criterion.jl | 38 ++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/src/plans/stopping_criterion.jl b/src/plans/stopping_criterion.jl index 1cdb945794..a3876d711b 100644 --- a/src/plans/stopping_criterion.jl +++ b/src/plans/stopping_criterion.jl @@ -501,30 +501,48 @@ end A stopping criterion based on the current gradient norm. +# Fields + +* `norm` – a function `(M, p, X) -> R` that computes a norm of the gradient `X` in the tangent space at `p` on `M`` +* `threshold` – the threshold to indicate to stop when the distance is below this value + +# Internal fields + +* `reason` – store a string reason when the stop was indicated +* `at_iteration` – store the iteration at which the stop indication happened + + # Constructor - StopWhenGradientNormLess(ε::Float64) + StopWhenGradientNormLess(ε; norm=(M,p,X) -> norm(M,p,X)) Create a stopping criterion with threshold `ε` for the gradient, that is, this criterion -indicates to stop when [`get_gradient`](@ref) returns a gradient vector of norm less than `ε`. +indicates to stop when [`get_gradient`](@ref) returns a gradient vector of norm less than `ε`, +where the norm to use can be specified in the `norm=` keyword. """ -mutable struct StopWhenGradientNormLess <: StoppingCriterion +mutable struct StopWhenGradientNormLess{F,TF} <: StoppingCriterion + norm::F threshold::Float64 reason::String at_iteration::Int - StopWhenGradientNormLess(ε::Float64) = new(ε, "", 0) + function StopWhenGradientNormLess( + ε::TF; norm::F=(M, p, X) -> norm(M, p, X) + ) where {F,TF} + return new{F,TF}(norm, ε, "", 0) + end end -function (c::StopWhenGradientNormLess)( + +function (sc::StopWhenGradientNormLess)( mp::AbstractManoptProblem, s::AbstractManoptSolverState, i::Int ) M = get_manifold(mp) if i == 0 # reset on init - c.reason = "" - c.at_iteration = 0 + sc.reason = "" + sc.at_iteration = 0 end - if (norm(M, get_iterate(s), get_gradient(s)) < c.threshold) && (i > 0) - c.reason = "The algorithm reached approximately critical point after $i iterations; the gradient norm ($(norm(M,get_iterate(s),get_gradient(s)))) is less than $(c.threshold).\n" - c.at_iteration = i + if (i > 0) && (sc.norm(M, get_iterate(s), get_gradient(s)) < sc.threshold) + sc.reason = "The algorithm reached approximately critical point after $i iterations; the gradient norm ($(sc.norm(M,get_iterate(s),get_gradient(s)))) is less than $(sc.threshold).\n" + sc.at_iteration = i return true end return false From 0489058f04f9e621d6f002972d56890c590693a2 Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Tue, 16 Jan 2024 18:23:08 +0100 Subject: [PATCH 03/11] Update changelog. --- Changelog.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/Changelog.md b/Changelog.md index f28ebc36cc..be8cc1edeb 100644 --- a/Changelog.md +++ b/Changelog.md @@ -5,7 +5,14 @@ All notable Changes to the Julia package `Manopt.jl` will be documented in this The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.4.48] +## [0.4.49] January 17, 2024 + +### Added + +* A `StopWhenEntryChangeLess` to be able to stop on arbitrary small changes of specific fields +* generalises `StopWhenGradientNormLess` to accept arbitrary `norm=` functions + +## [0.4.48] January 16, 2024 ### Fixed From 2609d7099b8ef269892e85ad69ca85390a98f47c Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Wed, 17 Jan 2024 17:23:42 +0100 Subject: [PATCH 04/11] Switch PSO default to use a :swarm change instead and adapt tests. --- Changelog.md | 2 ++ src/plans/stopping_criterion.jl | 10 +++++----- src/solvers/particle_swarm.jl | 33 +++++++++------------------------ 3 files changed, 16 insertions(+), 29 deletions(-) diff --git a/Changelog.md b/Changelog.md index be8cc1edeb..c6254a3856 100644 --- a/Changelog.md +++ b/Changelog.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * A `StopWhenEntryChangeLess` to be able to stop on arbitrary small changes of specific fields * generalises `StopWhenGradientNormLess` to accept arbitrary `norm=` functions +* refactor the default in `particle_swarm` to no longer “misuse” the iteration change check, + but actually the new one one the `:swarm` entry ## [0.4.48] January 16, 2024 diff --git a/src/plans/stopping_criterion.jl b/src/plans/stopping_criterion.jl index a3876d711b..9d801bcea2 100644 --- a/src/plans/stopping_criterion.jl +++ b/src/plans/stopping_criterion.jl @@ -348,7 +348,7 @@ For the storage a [`StoreStateAction`](@ref) is used ) """ -mutable struct StopWhenEntryChangeLess{F,TI,TF,TSSA<:StoreStateAction} <: StoppingCriterion +mutable struct StopWhenEntryChangeLess{F,TF,TSSA<:StoreStateAction} <: StoppingCriterion at_iteration::Int distance::F field::Symbol @@ -371,7 +371,7 @@ function (sc::StopWhenEntryChangeLess)( end if has_storage(sc.storage, sc.field) old_field_value = get_storage(sc.storage, sc.field) - ε = sc.distance(mp, s, old_field_value, getproperty(st, d.field)) + ε = sc.distance(mp, s, old_field_value, getproperty(s, sc.field)) if (i > 0) && (ε < sc.threshold) sc.reason = "The algorithm performed a step with a change ($ε) in $(sc.field) less than $(sc.threshold).\n" sc.at_iteration = i @@ -382,10 +382,10 @@ function (sc::StopWhenEntryChangeLess)( sc.storage(mp, s, i) return false end -function status_summary(c::StopWhenEntryChangeLess) - has_stopped = length(c.reason) > 0 +function status_summary(sc::StopWhenEntryChangeLess) + has_stopped = length(sc.reason) > 0 s = has_stopped ? "reached" : "not reached" - return "|Δ:$(field)| < $(c.threshold): $s" + return "|Δ:$(sc.field)| < $(sc.threshold): $s" end """ diff --git a/src/solvers/particle_swarm.jl b/src/solvers/particle_swarm.jl index 41a0de6d77..9b65c3f237 100644 --- a/src/solvers/particle_swarm.jl +++ b/src/solvers/particle_swarm.jl @@ -297,7 +297,15 @@ function particle_swarm!( social_weight::Real=1.4, cognitive_weight::Real=1.4, stopping_criterion::StoppingCriterion=StopAfterIteration(500) | - StopWhenChangeLess(1e-4), + StopWhenEntryChangeLess( + :swarm, + (p, st, old_swarm, swarm) -> distance( + PowerManifold(get_manifold(p), NestedPowerRepresentation(), length(swarm)), + old_swarm, + swarm, + ), + 1e-4, + ), retraction_method::AbstractRetractionMethod=default_retraction_method(M, eltype(swarm)), inverse_retraction_method::AbstractInverseRetractionMethod=default_inverse_retraction_method( M, eltype(swarm) @@ -366,26 +374,3 @@ function step_solver!(mp::AbstractManoptProblem, s::ParticleSwarmState, ::Any) end end end -# -# Change not only refers to different iterates (best visited) but to whole `swarm` -# but also lives in the power manifold on M, so we have to adapt StopWhenChangeless -# -function (c::StopWhenChangeLess)(mp::AbstractManoptProblem, s::ParticleSwarmState, i) - if has_storage(c.storage, :Population) - swarm_old = get_storage(c.storage, :Population) - n = length(s.swarm) - d = distance( - PowerManifold(get_manifold(mp), NestedPowerRepresentation(), n), - s.swarm, - swarm_old, - ) - if d < c.threshold && i > 0 - c.reason = "The algorithm performed a step with a change ($d in the population) less than $(c.threshold).\n" - c.at_iteration = i - c.storage(mp, s, i) - return true - end - end - c.storage(mp, s, i) - return false -end From f4019a4742c544d7e9b32c832a3fd2a2d51170b1 Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Wed, 17 Jan 2024 17:55:24 +0100 Subject: [PATCH 05/11] Improve test coverage. --- src/Manopt.jl | 1 + src/plans/stopping_criterion.jl | 3 +++ test/plans/test_stopping_criteria.jl | 13 ++++++++++--- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/Manopt.jl b/src/Manopt.jl index 5898fec2e6..b2c20a6061 100644 --- a/src/Manopt.jl +++ b/src/Manopt.jl @@ -440,6 +440,7 @@ export StopAfter, StopWhenChangeLess, StopWhenCostLess, StopWhenCurvatureIsNegative, + StopWhenEntryChangeLess, StopWhenGradientChangeLess, StopWhenGradientNormLess, StopWhenFirstOrderProgress, diff --git a/src/plans/stopping_criterion.jl b/src/plans/stopping_criterion.jl index 9d801bcea2..60bf00341a 100644 --- a/src/plans/stopping_criterion.jl +++ b/src/plans/stopping_criterion.jl @@ -397,6 +397,9 @@ function update_stopping_criterion!(c::StopWhenEntryChangeLess, ::Val{:Threshold c.threshold = v return c end +function show(io::IO, c::StopWhenEntryChangeLess) + return print(io, "StopWhenEntryChangeLess\n $(status_summary(c))") +end @doc raw""" StopWhenGradientChangeLess <: StoppingCriterion diff --git a/test/plans/test_stopping_criteria.jl b/test/plans/test_stopping_criteria.jl index f0267dd95d..f15e191eec 100644 --- a/test/plans/test_stopping_criteria.jl +++ b/test/plans/test_stopping_criteria.jl @@ -50,13 +50,13 @@ end @testset "Test StopAfter" begin p = TestStopProblem() o = TestStopState() - s = StopAfter(Second(1)) + s = StopAfter(Millisecond(30)) @test !Manopt.indicates_convergence(s) @test Manopt.status_summary(s) == "stopped after $(s.threshold):\tnot reached" - @test repr(s) == "StopAfter(Second(1))\n $(Manopt.status_summary(s))" + @test repr(s) == "StopAfter(Millisecond(30))\n $(Manopt.status_summary(s))" s(p, o, 0) # Start @test s(p, o, 1) == false - sleep(1.02) + sleep(0.05) @test s(p, o, 2) == true @test_throws ErrorException StopAfter(Second(-1)) @test_throws ErrorException update_stopping_criterion!(s, :MaxTime, Second(-1)) @@ -170,3 +170,10 @@ end update_stopping_criterion!(s1, :MinStepsize, 1e-1) @test s1.threshold == 1e-1 end + +@testset "Test further setters" begin + swecl = StopWhenEntryChangeLess(:dummy, (p, s, v, w) -> norm(w - v), 1e-5) + @test startswith(repr(swecl), "StopWhenEntryChangeLess\n") + update_stopping_criterion!(swecl, :Threshold, 1e-1) + @test swecl.threshold == 1e-1 +end From f19cc8a33c9fd8b07b1b3fd6a8f4995a23c209d4 Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Wed, 17 Jan 2024 18:48:36 +0100 Subject: [PATCH 06/11] Apply suggestions from code review Co-authored-by: Mateusz Baran --- src/plans/stopping_criterion.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/plans/stopping_criterion.jl b/src/plans/stopping_criterion.jl index 60bf00341a..983b231228 100644 --- a/src/plans/stopping_criterion.jl +++ b/src/plans/stopping_criterion.jl @@ -506,7 +506,7 @@ A stopping criterion based on the current gradient norm. # Fields -* `norm` – a function `(M, p, X) -> R` that computes a norm of the gradient `X` in the tangent space at `p` on `M`` +* `norm` – a function `(M::AbstractManifold, p, X) -> ℝ` that computes a norm of the gradient `X` in the tangent space at `p` on `M`` * `threshold` – the threshold to indicate to stop when the distance is below this value # Internal fields @@ -529,7 +529,7 @@ mutable struct StopWhenGradientNormLess{F,TF} <: StoppingCriterion reason::String at_iteration::Int function StopWhenGradientNormLess( - ε::TF; norm::F=(M, p, X) -> norm(M, p, X) + ε::TF; norm::F=norm ) where {F,TF} return new{F,TF}(norm, ε, "", 0) end From caa87df2d47bf5ef110a249ec95255a1e8d56549 Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Wed, 17 Jan 2024 18:52:16 +0100 Subject: [PATCH 07/11] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/plans/stopping_criterion.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/plans/stopping_criterion.jl b/src/plans/stopping_criterion.jl index 983b231228..3687cea991 100644 --- a/src/plans/stopping_criterion.jl +++ b/src/plans/stopping_criterion.jl @@ -528,9 +528,7 @@ mutable struct StopWhenGradientNormLess{F,TF} <: StoppingCriterion threshold::Float64 reason::String at_iteration::Int - function StopWhenGradientNormLess( - ε::TF; norm::F=norm - ) where {F,TF} + function StopWhenGradientNormLess(ε::TF; norm::F=norm) where {F,TF} return new{F,TF}(norm, ε, "", 0) end end From 0d591eb859a11cc7c03f69c81bb2b846b01fa371 Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Thu, 18 Jan 2024 09:48:13 +0100 Subject: [PATCH 08/11] Update src/plans/stopping_criterion.jl Co-authored-by: Mateusz Baran --- src/plans/stopping_criterion.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/plans/stopping_criterion.jl b/src/plans/stopping_criterion.jl index 3687cea991..adf2d41310 100644 --- a/src/plans/stopping_criterion.jl +++ b/src/plans/stopping_criterion.jl @@ -541,10 +541,13 @@ function (sc::StopWhenGradientNormLess)( sc.reason = "" sc.at_iteration = 0 end - if (i > 0) && (sc.norm(M, get_iterate(s), get_gradient(s)) < sc.threshold) - sc.reason = "The algorithm reached approximately critical point after $i iterations; the gradient norm ($(sc.norm(M,get_iterate(s),get_gradient(s)))) is less than $(sc.threshold).\n" - sc.at_iteration = i - return true + if (i > 0) + grad_norm = sc.norm(M, get_iterate(s), get_gradient(s)) + if grad_norm < sc.threshold) + sc.reason = "The algorithm reached approximately critical point after $i iterations; the gradient norm ($(grad_norm)) is less than $(sc.threshold).\n" + sc.at_iteration = i + return true + end end return false end From 687ecda7f617ade240400ec826930bda16a2846b Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Thu, 18 Jan 2024 11:20:05 +0100 Subject: [PATCH 09/11] Update src/plans/stopping_criterion.jl --- src/plans/stopping_criterion.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plans/stopping_criterion.jl b/src/plans/stopping_criterion.jl index adf2d41310..3c30b0a933 100644 --- a/src/plans/stopping_criterion.jl +++ b/src/plans/stopping_criterion.jl @@ -543,7 +543,7 @@ function (sc::StopWhenGradientNormLess)( end if (i > 0) grad_norm = sc.norm(M, get_iterate(s), get_gradient(s)) - if grad_norm < sc.threshold) + if grad_norm < sc.threshold sc.reason = "The algorithm reached approximately critical point after $i iterations; the gradient norm ($(grad_norm)) is less than $(sc.threshold).\n" sc.at_iteration = i return true From 747efe6a53002fd0e3e0d72b29f2fe1d8d9c384f Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Thu, 18 Jan 2024 11:21:58 +0100 Subject: [PATCH 10/11] Update src/plans/stopping_criterion.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/plans/stopping_criterion.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plans/stopping_criterion.jl b/src/plans/stopping_criterion.jl index 3c30b0a933..b53f29efd9 100644 --- a/src/plans/stopping_criterion.jl +++ b/src/plans/stopping_criterion.jl @@ -541,7 +541,7 @@ function (sc::StopWhenGradientNormLess)( sc.reason = "" sc.at_iteration = 0 end - if (i > 0) + if (i > 0) grad_norm = sc.norm(M, get_iterate(s), get_gradient(s)) if grad_norm < sc.threshold sc.reason = "The algorithm reached approximately critical point after $i iterations; the gradient norm ($(grad_norm)) is less than $(sc.threshold).\n" From 18efdc9d2325e44322b141b1c9e6a52dc8b1b015 Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Thu, 18 Jan 2024 18:34:14 +0100 Subject: [PATCH 11/11] Fix changelog. --- Changelog.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Changelog.md b/Changelog.md index c6254a3856..9436a614b3 100644 --- a/Changelog.md +++ b/Changelog.md @@ -5,7 +5,7 @@ All notable Changes to the Julia package `Manopt.jl` will be documented in this The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.4.49] January 17, 2024 +## [0.4.49] January 18, 2024 ### Added