-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,13 +9,14 @@ using BenchmarkTools | |
using Plots | ||
|
||
""" | ||
StopWhenGradientNormLess <: StoppingCriterion | ||
StopWhenGradientInfNormLess <: StoppingCriterion | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
mateuszbaran
Author
Member
|
||
A stopping criterion based on the current gradient norm. | ||
A stopping criterion based on the current gradient infinity norm in a basis arbitrarily | ||
chosen for each manifold. | ||
# Constructor | ||
StopWhenGradientNormLess(ε::Float64) | ||
StopWhenGradientInfNormLess(ε::Float64) | ||
Create a stopping criterion with threshold `ε` for the gradient, that is, this criterion | ||
indicates to stop when [`get_gradient`](@ref) returns a gradient vector of norm less than `ε`. | ||
|
@@ -116,7 +117,25 @@ function prof() | |
return ProfileView.view() | ||
end | ||
|
||
function generate_cmp(f, g!, f_manopt, g_manopt!) | ||
"""
    manifold_maker(name::Symbol, N, lib::Symbol)

Construct the manifold identified by `name` (`:Euclidean` or `:Sphere`) with
dimension parameter `N`, using the representation expected by the given
optimization library `lib` (`:Manopt` or `:Optim`).

Note the dimension convention differs per library: for `:Manopt`, `:Sphere`
maps to `Manifolds.Sphere(N - 1)` (the sphere embedded in R^N), while for
`:Optim` the manifold objects carry no dimension.

Throws an error for an unknown `lib` or an unknown manifold `name`
(the original fell through and silently returned `nothing` for an
unknown `name` with a known `lib`).
"""
function manifold_maker(name::Symbol, N, lib::Symbol)
    if lib === :Manopt
        if name === :Euclidean
            return Euclidean(N)
        elseif name === :Sphere
            return Manifolds.Sphere(N - 1)
        else
            error("Unknown manifold name: $name")
        end
    elseif lib === :Optim
        if name === :Euclidean
            return Optim.Flat()
        elseif name === :Sphere
            return Optim.Sphere()
        else
            error("Unknown manifold name: $name")
        end
    else
        error("Unknown library: $lib")
    end
end
|
||
function generate_cmp(f, g!, f_manopt, g_manopt!; mem_len::Int=2) | ||
plt = plot() | ||
xlabel!("dimension") | ||
ylabel!("time [ms]") | ||
|
@@ -129,55 +148,103 @@ function generate_cmp(f, g!, f_manopt, g_manopt!) | |
ls_hz = LineSearches.HagerZhang() | ||
|
||
gtol = 1e-6 | ||
mem_len = 10 | ||
println("Benchmarking $f for gtol=$gtol") | ||
for N in N_vals | ||
println("Benchmarking for N=$N") | ||
M = Euclidean(N) | ||
method_optim = LBFGS(; m=mem_len, linesearch=ls_hz, manifold=Optim.Flat()) | ||
|
||
x0 = zeros(N) | ||
manopt_sc = StopWhenGradientInfNormLess(gtol) | StopAfterIteration(1000) | ||
bench_manopt = @benchmark quasi_Newton( | ||
$M, | ||
$f_manopt, | ||
$g_manopt!, | ||
$x0; | ||
stepsize=$(Manopt.LineSearchesStepsize(ls_hz)), | ||
evaluation=$(InplaceEvaluation()), | ||
memory_size=$mem_len, | ||
stopping_criterion=$(manopt_sc), | ||
for manifold_name in [:Euclidean, :Sphere] | ||
println("Benchmarking $f for gtol=$gtol on $manifold_name") | ||
for N in N_vals | ||
println("Benchmarking for N=$N") | ||
M = manifold_maker(manifold_name, N, :Manopt) | ||
method_optim = LBFGS(; | ||
m=mem_len, | ||
linesearch=ls_hz, | ||
manifold=manifold_maker(manifold_name, N, :Optim), | ||
) | ||
|
||
x0 = zeros(N) | ||
x0[1] = 1 | ||
manopt_sc = StopWhenGradientInfNormLess(gtol) | StopAfterIteration(1000) | ||
bench_manopt = @benchmark quasi_Newton( | ||
$M, | ||
$f_manopt, | ||
$g_manopt!, | ||
$x0; | ||
stepsize=$(Manopt.LineSearchesStepsize(ls_hz)), | ||
evaluation=$(InplaceEvaluation()), | ||
memory_size=$mem_len, | ||
stopping_criterion=$(manopt_sc), | ||
) | ||
|
||
manopt_state = quasi_Newton( | ||
M, | ||
f_manopt, | ||
g_manopt!, | ||
x0; | ||
stepsize=Manopt.LineSearchesStepsize(ls_hz), | ||
evaluation=InplaceEvaluation(), | ||
return_state=true, | ||
memory_size=mem_len, | ||
stopping_criterion=manopt_sc, | ||
) | ||
manopt_iters = get_count(manopt_state, :Iterations) | ||
push!(times_manopt, median(bench_manopt.times) / 1000) | ||
println("Manopt.jl time: $(median(bench_manopt.times) / 1000) ms") | ||
println("Manopt.jl iterations: $(manopt_iters)") | ||
|
||
options_optim = Optim.Options(; g_tol=gtol) | ||
bench_optim = @benchmark optimize($f, $g!, $x0, $method_optim, $options_optim) | ||
|
||
optim_state = optimize( | ||
f_rosenbrock, g_rosenbrock!, x0, method_optim, options_optim | ||
) | ||
println("Optim.jl time: $(median(bench_optim.times) / 1000) ms") | ||
push!(times_optim, median(bench_optim.times) / 1000) | ||
println("Optim.jl iterations: $(optim_state.iterations)") | ||
end | ||
plot!( | ||
N_vals, times_manopt; label="Manopt.jl ($manifold_name)", xaxis=:log, yaxis=:log | ||
) | ||
|
||
manopt_state = quasi_Newton( | ||
M, | ||
f_manopt, | ||
g_manopt!, | ||
x0; | ||
stepsize=Manopt.LineSearchesStepsize(ls_hz), | ||
evaluation=InplaceEvaluation(), | ||
return_state=true, | ||
memory_size=mem_len, | ||
stopping_criterion=manopt_sc, | ||
plot!( | ||
N_vals, times_optim; label="Optim.jl ($manifold_name)", xaxis=:log, yaxis=:log | ||
) | ||
manopt_iters = get_count(manopt_state, :Iterations) | ||
push!(times_manopt, median(bench_manopt.times) / 1000) | ||
println("Manopt.jl time: $(median(bench_manopt.times) / 1000) ms") | ||
println("Manopt.jl iterations: $(manopt_iters)") | ||
|
||
options_optim = Optim.Options(; g_tol=gtol) | ||
bench_optim = @benchmark optimize($f, $g!, $x0, $method_optim, $options_optim) | ||
|
||
optim_state = optimize(f_rosenbrock, g_rosenbrock!, x0, method_optim, options_optim) | ||
println("Optim.jl time: $(median(bench_optim.times) / 1000) ms") | ||
push!(times_optim, median(bench_optim.times) / 1000) | ||
println("Optim.jl iterations: $(optim_state.iterations)") | ||
end | ||
plot!(N_vals, times_manopt; label="Manopt.jl", xaxis=:log, yaxis=:log) | ||
plot!(N_vals, times_optim; label="Optim.jl", xaxis=:log, yaxis=:log) | ||
xticks!(N_vals, string.(N_vals)) | ||
|
||
return plt | ||
end | ||
|
||
generate_cmp(f_rosenbrock, g_rosenbrock!, f_rosenbrock_manopt, g_rosenbrock_manopt!) | ||
#generate_cmp(f_rosenbrock, g_rosenbrock!, f_rosenbrock_manopt, g_rosenbrock_manopt!) | ||
|
||
"""
    test_case_manopt()

Run a small (dimension 4) Rosenbrock quasi-Newton optimization with
Manopt.jl on the Euclidean manifold and return the solver state.
"""
function test_case_manopt()
    dim = 4
    memory = 2
    M = Manifolds.Euclidean(dim)
    linesearch = LineSearches.HagerZhang()

    # Start from the first standard basis vector.
    p0 = zeros(dim)
    p0[1] = 1
    stop_crit = StopWhenGradientInfNormLess(1e-6) | StopAfterIteration(1000)

    return quasi_Newton(
        M,
        f_rosenbrock_manopt,
        g_rosenbrock_manopt!,
        p0;
        stepsize=Manopt.LineSearchesStepsize(linesearch),
        evaluation=InplaceEvaluation(),
        return_state=true,
        memory_size=memory,
        stopping_criterion=stop_crit,
    )
end
|
||
"""
    test_case_optim()

Run a small (dimension 4) Rosenbrock L-BFGS optimization with Optim.jl
(flat manifold, Hager-Zhang line search) and return the optimization result.
"""
function test_case_optim()
    dim = 4
    memory = 2
    linesearch = LineSearches.HagerZhang()
    solver = LBFGS(; m=memory, linesearch=linesearch, manifold=Optim.Flat())
    opts = Optim.Options(; g_tol=1e-6)

    # Start from the first standard basis vector.
    start = zeros(dim)
    start[1] = 1
    return optimize(f_rosenbrock, g_rosenbrock!, start, solver, opts)
end
We could also extend `StopWhenGradientNormLess` to have a `norm=` argument that takes a norm function.