diff --git a/Changelog.md b/Changelog.md index a5b30e3b10..4934dc1e41 100644 --- a/Changelog.md +++ b/Changelog.md @@ -11,6 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Allow the `message=` of the `DebugIfEntry` debug action to contain a format element to print the field in the message as well. +## [0.4.51] January 30, 2024 + +### Added + +* A `StopWhenSubgradientNormLess` stopping criterion for subgradient-based optimization. + ## [0.4.50] January 26, 2024 ### Fixed diff --git a/Project.toml b/Project.toml index 13424a8dfb..71e6be3120 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Manopt" uuid = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5" authors = ["Ronny Bergmann "] -version = "0.4.50" +version = "0.4.51" [deps] ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4" diff --git a/src/Manopt.jl b/src/Manopt.jl index b2c20a6061..3e7901d19c 100644 --- a/src/Manopt.jl +++ b/src/Manopt.jl @@ -448,6 +448,7 @@ export StopAfter, StopWhenPopulationConcentrated, StopWhenSmallerOrEqual, StopWhenStepsizeLess, + StopWhenSubgradientNormLess, StopWhenTrustRegionIsExceeded export get_active_stopping_criteria, get_stopping_criteria, get_reason, get_stopping_criterion diff --git a/src/plans/stopping_criterion.jl b/src/plans/stopping_criterion.jl index b53f29efd9..3d95df90f0 100644 --- a/src/plans/stopping_criterion.jl +++ b/src/plans/stopping_criterion.jl @@ -173,6 +173,57 @@ function update_stopping_criterion!(c::StopAfterIteration, ::Val{:MaxIteration}, return c end +""" + StopWhenSubgradientNormLess <: StoppingCriterion + +A stopping criterion based on the current subgradient norm. + +# Constructor + + StopWhenSubgradientNormLess(ε::Float64) + +Create a stopping criterion with threshold `ε` for the subgradient, that is, this criterion +indicates to stop when [`get_subgradient`](@ref) returns a subgradient vector of norm less than `ε`. +""" +mutable struct StopWhenSubgradientNormLess <: StoppingCriterion + threshold::Float64 + reason::String + StopWhenSubgradientNormLess(ε::Float64) = new(ε, "") +end +function (c::StopWhenSubgradientNormLess)( + mp::AbstractManoptProblem, s::AbstractManoptSolverState, i::Int +) + M = get_manifold(mp) + (i == 0) && (c.reason = "") # reset on init + if (norm(M, get_iterate(s), get_subgradient(s)) < c.threshold) && (i > 0) + c.reason = "The algorithm reached approximately critical point after $i iterations; the subgradient norm ($(norm(M,get_iterate(s),get_subgradient(s)))) is less than $(c.threshold).\n" + return true + end + return false +end +function status_summary(c::StopWhenSubgradientNormLess) + has_stopped = length(c.reason) > 0 + s = has_stopped ? "reached" : "not reached" + return "|subgrad f| < $(c.threshold): $s" +end +indicates_convergence(c::StopWhenSubgradientNormLess) = true +function show(io::IO, c::StopWhenSubgradientNormLess) + return print( + io, "StopWhenSubgradientNormLess($(c.threshold))\n $(status_summary(c))" + ) +end +""" + update_stopping_criterion!(c::StopWhenSubgradientNormLess, :MinSubgradNorm, v::Float64) + +Update the minimal subgradient norm when an algorithm shall stop +""" +function update_stopping_criterion!( + c::StopWhenSubgradientNormLess, ::Val{:MinSubgradNorm}, v::Float64 +) + c.threshold = v + return c +end + """ StopWhenChangeLess <: StoppingCriterion diff --git a/test/plans/test_stopping_criteria.jl b/test/plans/test_stopping_criteria.jl index f15e191eec..474313604e 100644 --- a/test/plans/test_stopping_criteria.jl +++ b/test/plans/test_stopping_criteria.jl @@ -74,6 +74,9 @@ end c = StopWhenGradientNormLess(1e-6) sc = "StopWhenGradientNormLess(1.0e-6)\n $(Manopt.status_summary(c))" @test repr(c) == sc + c2 = StopWhenSubgradientNormLess(1e-6) + sc2 = "StopWhenSubgradientNormLess(1.0e-6)\n $(Manopt.status_summary(c2))" + @test repr(c2) == sc2 d = StopWhenAll(a, b, c) @test typeof(d) === typeof(a & b & c) @test typeof(d) === typeof(a & (b & c))