diff --git a/Changelog.md b/Changelog.md index f28ebc36cc..9436a614b3 100644 --- a/Changelog.md +++ b/Changelog.md @@ -5,7 +5,16 @@ All notable Changes to the Julia package `Manopt.jl` will be documented in this The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.4.48] +## [0.4.49] January 18, 2024 + +### Added + +* A `StopWhenEntryChangeLess` to be able to stop on arbitrary small changes of specific fields +* generalises `StopWhenGradientNormLess` to accept arbitrary `norm=` functions +* refactor the default in `particle_swarm` to no longer “misuse” the iteration change check, + but actually the new one one the `:swarm` entry + +## [0.4.48] January 16, 2024 ### Fixed diff --git a/src/Manopt.jl b/src/Manopt.jl index 5898fec2e6..b2c20a6061 100644 --- a/src/Manopt.jl +++ b/src/Manopt.jl @@ -440,6 +440,7 @@ export StopAfter, StopWhenChangeLess, StopWhenCostLess, StopWhenCurvatureIsNegative, + StopWhenEntryChangeLess, StopWhenGradientChangeLess, StopWhenGradientNormLess, StopWhenFirstOrderProgress, diff --git a/src/plans/stopping_criterion.jl b/src/plans/stopping_criterion.jl index a254a3f71d..b53f29efd9 100644 --- a/src/plans/stopping_criterion.jl +++ b/src/plans/stopping_criterion.jl @@ -316,6 +316,91 @@ function update_stopping_criterion!(c::StopWhenCostLess, ::Val{:MinCost}, v) return c end +@doc raw""" + StopWhenEntryChangeLess + +Evaluate whether a certain fields change is less than a certain threshold + +## Fields + +* `field` – a symbol adressing the corresponding field in a certain subtype of [`AbstractManoptSolverState`](@ref) + to track +* `distance` – a function `(problem, state, v1, v2) -> R` that computes the distance between two possible values of the `field` +* `storage` – a [`StoreStateAction`](@ref) to store the previous value of the `field` +* `threshold` – the threshold to indicate to stop when the distance is below this value + +# Internal fields + +* `reason` – store a string reason when the stop was indicated +* `at_iteration` – store the iteration at which the stop indication happened + +stores a threshold when to stop looking at the norm of the change of the +optimization variable from within a [`AbstractManoptSolverState`](@ref), i.e `get_iterate(o)`. +For the storage a [`StoreStateAction`](@ref) is used + +# Constructor + + StopWhenEntryChangeLess( + field::Symbol + distance, + threshold; + storage::StoreStateAction=StoreStateAction([field]), + ) + +""" +mutable struct StopWhenEntryChangeLess{F,TF,TSSA<:StoreStateAction} <: StoppingCriterion + at_iteration::Int + distance::F + field::Symbol + reason::String + storage::TSSA + threshold::TF +end +function StopWhenEntryChangeLess( + field::Symbol, distance::F, threshold::TF; storage::TSSA=StoreStateAction([field]) +) where {F,TF,TSSA<:StoreStateAction} + return StopWhenEntryChangeLess{F,TF,TSSA}(0, distance, field, "", storage, threshold) +end + +function (sc::StopWhenEntryChangeLess)( + mp::AbstractManoptProblem, s::AbstractManoptSolverState, i +) + if i == 0 # reset on init + sc.reason = "" + sc.at_iteration = 0 + end + if has_storage(sc.storage, sc.field) + old_field_value = get_storage(sc.storage, sc.field) + ε = sc.distance(mp, s, old_field_value, getproperty(s, sc.field)) + if (i > 0) && (ε < sc.threshold) + sc.reason = "The algorithm performed a step with a change ($ε) in $(sc.field) less than $(sc.threshold).\n" + sc.at_iteration = i + sc.storage(mp, s, i) + return true + end + end + sc.storage(mp, s, i) + return false +end +function status_summary(sc::StopWhenEntryChangeLess) + has_stopped = length(sc.reason) > 0 + s = has_stopped ? "reached" : "not reached" + return "|Δ:$(sc.field)| < $(sc.threshold): $s" +end + +""" + update_stopping_criterion!(c::StopWhenEntryChangeLess, :Threshold, v) + +Update the minimal cost below which the algorithm shall stop +""" +function update_stopping_criterion!(c::StopWhenEntryChangeLess, ::Val{:Threshold}, v) + c.threshold = v + return c +end +function show(io::IO, c::StopWhenEntryChangeLess) + return print(io, "StopWhenEntryChangeLess\n $(status_summary(c))") +end + @doc raw""" StopWhenGradientChangeLess <: StoppingCriterion @@ -419,31 +504,50 @@ end A stopping criterion based on the current gradient norm. +# Fields + +* `norm` – a function `(M::AbstractManifold, p, X) -> ℝ` that computes a norm of the gradient `X` in the tangent space at `p` on `M`` +* `threshold` – the threshold to indicate to stop when the distance is below this value + +# Internal fields + +* `reason` – store a string reason when the stop was indicated +* `at_iteration` – store the iteration at which the stop indication happened + + # Constructor - StopWhenGradientNormLess(ε::Float64) + StopWhenGradientNormLess(ε; norm=(M,p,X) -> norm(M,p,X)) Create a stopping criterion with threshold `ε` for the gradient, that is, this criterion -indicates to stop when [`get_gradient`](@ref) returns a gradient vector of norm less than `ε`. +indicates to stop when [`get_gradient`](@ref) returns a gradient vector of norm less than `ε`, +where the norm to use can be specified in the `norm=` keyword. """ -mutable struct StopWhenGradientNormLess <: StoppingCriterion +mutable struct StopWhenGradientNormLess{F,TF} <: StoppingCriterion + norm::F threshold::Float64 reason::String at_iteration::Int - StopWhenGradientNormLess(ε::Float64) = new(ε, "", 0) + function StopWhenGradientNormLess(ε::TF; norm::F=norm) where {F,TF} + return new{F,TF}(norm, ε, "", 0) + end end -function (c::StopWhenGradientNormLess)( + +function (sc::StopWhenGradientNormLess)( mp::AbstractManoptProblem, s::AbstractManoptSolverState, i::Int ) M = get_manifold(mp) if i == 0 # reset on init - c.reason = "" - c.at_iteration = 0 + sc.reason = "" + sc.at_iteration = 0 end - if (norm(M, get_iterate(s), get_gradient(s)) < c.threshold) && (i > 0) - c.reason = "The algorithm reached approximately critical point after $i iterations; the gradient norm ($(norm(M,get_iterate(s),get_gradient(s)))) is less than $(c.threshold).\n" - c.at_iteration = i - return true + if (i > 0) + grad_norm = sc.norm(M, get_iterate(s), get_gradient(s)) + if grad_norm < sc.threshold + sc.reason = "The algorithm reached approximately critical point after $i iterations; the gradient norm ($(grad_norm)) is less than $(sc.threshold).\n" + sc.at_iteration = i + return true + end end return false end diff --git a/src/solvers/particle_swarm.jl b/src/solvers/particle_swarm.jl index 41a0de6d77..9b65c3f237 100644 --- a/src/solvers/particle_swarm.jl +++ b/src/solvers/particle_swarm.jl @@ -297,7 +297,15 @@ function particle_swarm!( social_weight::Real=1.4, cognitive_weight::Real=1.4, stopping_criterion::StoppingCriterion=StopAfterIteration(500) | - StopWhenChangeLess(1e-4), + StopWhenEntryChangeLess( + :swarm, + (p, st, old_swarm, swarm) -> distance( + PowerManifold(get_manifold(p), NestedPowerRepresentation(), length(swarm)), + old_swarm, + swarm, + ), + 1e-4, + ), retraction_method::AbstractRetractionMethod=default_retraction_method(M, eltype(swarm)), inverse_retraction_method::AbstractInverseRetractionMethod=default_inverse_retraction_method( M, eltype(swarm) @@ -366,26 +374,3 @@ function step_solver!(mp::AbstractManoptProblem, s::ParticleSwarmState, ::Any) end end end -# -# Change not only refers to different iterates (best visited) but to whole `swarm` -# but also lives in the power manifold on M, so we have to adapt StopWhenChangeless -# -function (c::StopWhenChangeLess)(mp::AbstractManoptProblem, s::ParticleSwarmState, i) - if has_storage(c.storage, :Population) - swarm_old = get_storage(c.storage, :Population) - n = length(s.swarm) - d = distance( - PowerManifold(get_manifold(mp), NestedPowerRepresentation(), n), - s.swarm, - swarm_old, - ) - if d < c.threshold && i > 0 - c.reason = "The algorithm performed a step with a change ($d in the population) less than $(c.threshold).\n" - c.at_iteration = i - c.storage(mp, s, i) - return true - end - end - c.storage(mp, s, i) - return false -end diff --git a/test/plans/test_stopping_criteria.jl b/test/plans/test_stopping_criteria.jl index f0267dd95d..f15e191eec 100644 --- a/test/plans/test_stopping_criteria.jl +++ b/test/plans/test_stopping_criteria.jl @@ -50,13 +50,13 @@ end @testset "Test StopAfter" begin p = TestStopProblem() o = TestStopState() - s = StopAfter(Second(1)) + s = StopAfter(Millisecond(30)) @test !Manopt.indicates_convergence(s) @test Manopt.status_summary(s) == "stopped after $(s.threshold):\tnot reached" - @test repr(s) == "StopAfter(Second(1))\n $(Manopt.status_summary(s))" + @test repr(s) == "StopAfter(Millisecond(30))\n $(Manopt.status_summary(s))" s(p, o, 0) # Start @test s(p, o, 1) == false - sleep(1.02) + sleep(0.05) @test s(p, o, 2) == true @test_throws ErrorException StopAfter(Second(-1)) @test_throws ErrorException update_stopping_criterion!(s, :MaxTime, Second(-1)) @@ -170,3 +170,10 @@ end update_stopping_criterion!(s1, :MinStepsize, 1e-1) @test s1.threshold == 1e-1 end + +@testset "Test further setters" begin + swecl = StopWhenEntryChangeLess(:dummy, (p, s, v, w) -> norm(w - v), 1e-5) + @test startswith(repr(swecl), "StopWhenEntryChangeLess\n") + update_stopping_criterion!(swecl, :Threshold, 1e-1) + @test swecl.threshold == 1e-1 +end