Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Sep 29, 2022
1 parent c41a5bf commit 9238530
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonconvexMMA"
uuid = "d3d89cbb-4ecd-4604-818d-8d1ff343e4da"
authors = ["Mohamed Tarek <[email protected]> and contributors"]
version = "0.1.5"
version = "1.0.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
17 changes: 9 additions & 8 deletions src/mma_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ A struct that stores all the options of the MMA algorithms. Th following are the
- `s_decr`: defined in the original [`MMA02`](@ref) paper.
- `store_trace`: if true, a trace will be stored.
- `dual_options`: the options passed to the dual optimizer from [`Optim.jl`](https://github.com/JuliaNLSolvers/Optim.jl).
- `convcriteria`: an instance of [`ConvergenceCriteria`](@ref) that specifies the convergence criteria of the MMA algorithm.
- `verbose`: true/false, when true prints convergence statistics.
"""
@with_kw mutable struct MMAOptions{T, Ttol <: Tolerance, TSubOptions <: Optim.Options}
@with_kw mutable struct MMAOptions{T, Ttol <: Tolerance, TSubOptions <: Optim.Options, TC <: ConvergenceCriteria}
maxiter::Int = 1000
outer_maxiter::Int = 10^8
maxinner::Int = 10
Expand All @@ -60,6 +62,8 @@ A struct that stores all the options of the MMA algorithms. Th following are the
iterations = 1000,
outer_iterations=1000,
)
convcriteria::TC = KKTCriteria()
verbose::Bool = true
end

"""
Expand All @@ -83,7 +87,6 @@ A struct that stores all the intermediate states and memory allocations needed f
- `ρ`: the `ρ` parameter as explained in [`MMAApprox`](@ref).
- `tempx`: a temporary vector used to store the 2nd previous primal solution.
- `options`: an instance of [`MMAOptions`](@ref) that resembles the options of the MMA algorithm.
- `convcriteria`: an instance of [`ConvergenceCriteria`](@ref) that specifies the convergence criteria of the MMA algorithm.
- `callback`: a function that is called on `solution` in every iteration of the algorithm. This can be used to store information about the optimization process.
- `optimizer`: an instance of [`AbstractOptimizer`](@ref) such as `MMA87()` or `MMA02()` that specifies the variant of MMA used to optimize the model.
- `suboptimizer`: the dual optimization algorithm used to optimize the barrier problem. This should be an [`Optim.jl`](https://github.com/JuliaNLSolvers/Optim.jl) optimizer.
Expand All @@ -100,7 +103,6 @@ A struct that stores all the intermediate states and memory allocations needed f
σ::AbstractVector
ρ::AbstractVector
tempx::AbstractVector
convcriteria::ConvergenceCriteria
callback::Function
optimizer::AbstractOptimizer
options
Expand All @@ -114,7 +116,6 @@ function MMAWorkspace(
optimizer::AbstractOptimizer,
x0::AbstractVector{T};
options = default_options(model, optimizer),
convcriteria::ConvergenceCriteria = KKTCriteria(),
plot_trace::Bool = false,
show_plot::Bool = plot_trace,
save_plot = nothing,
Expand All @@ -127,7 +128,7 @@ function MMAWorkspace(
# Convergence
λ = ones(getdim(getineqconstraints(model)))
solution = Solution(dualmodel, λ)
assess_convergence!(solution, model, options.tol, convcriteria)
assess_convergence!(solution, model, options.tol, options.convcriteria, options.verbose, 0)
correctsolution!(solution, model, options)

# Trace
Expand All @@ -147,7 +148,6 @@ function MMAWorkspace(
σ,
ρ,
tempx,
convcriteria,
callback,
optimizer,
options,
Expand Down Expand Up @@ -183,11 +183,12 @@ function Workspace(model::VecModel, optimizer::Union{MMA87, MMA02}, args...; kwa
end

function optimize!(workspace::MMAWorkspace)
@unpack dualmodel, solution, convcriteria = workspace
@unpack dualmodel, solution = workspace
@unpack callback, optimizer, options, trace = workspace
@unpack x0, σ, ρ, outer_iter, iter, fcalls = workspace
@unpack dualoptimizer = optimizer
@unpack dual_options, maxiter, outer_maxiter, auto_scale = options
@unpack convcriteria, verbose = options
@unpack prevx, x, g, λ = solution
best_solution = deepcopy(solution)

Expand Down Expand Up @@ -292,7 +293,7 @@ function optimize!(workspace::MMAWorkspace)
updatefg!(solution, fg, ∇fg)

# Check if the algorithm has converged
assess_convergence!(solution, model, options.tol, convcriteria)
assess_convergence!(solution, model, options.tol, convcriteria, verbose, iter)

# Callback, e.g. a trace plotting callback
callback(solution)
Expand Down
47 changes: 36 additions & 11 deletions test/mma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@ using NonconvexMMA, LinearAlgebra, Test, Zygote
f(x::AbstractVector) = x[2] < 0 ? Inf : sqrt(x[2])
g(x::AbstractVector, a, b) = (a*x[1] + b)^3 - x[2]

options = MMAOptions(
tol = Tolerance(kkt = 1e-6, f = 0.0),
s_init = 0.1,
)

@testset "Simple constraints" begin
m = Model(f)
addvar!(m, [0.0, 0.0], [10.0, 10.0])
Expand All @@ -16,7 +11,12 @@ options = MMAOptions(

@testset "MMA $(alg isa MMA87 ? "1987" : "2002")" for alg in (MMA87(), MMA02())
for convcriteria in (KKTCriteria(), IpoptCriteria())
r = NonconvexMMA.optimize(m, alg, [1.234, 2.345], options = options, convcriteria = convcriteria)
options = MMAOptions(;
tol = Tolerance(kkt = 1e-6, f = 0.0),
s_init = 0.1,
convcriteria,
)
r = NonconvexMMA.optimize(m, alg, [1.234, 2.345], options = options)
@test abs(r.minimum - sqrt(8/27)) < 1e-6
@test norm(r.minimizer - [1/3, 8/27]) < 1e-6
end
Expand All @@ -30,7 +30,12 @@ end

@testset "MMA $(alg isa MMA87 ? "1987" : "2002")" for alg in (MMA87(), MMA02())
for convcriteria in (KKTCriteria(), IpoptCriteria())
r = NonconvexMMA.optimize(m, alg, [1.234, 2.345], options = options, convcriteria = convcriteria)
options = MMAOptions(;
tol = Tolerance(kkt = 1e-6, f = 0.0),
s_init = 0.1,
convcriteria,
)
r = NonconvexMMA.optimize(m, alg, [1.234, 2.345], options = options)
@test abs(r.minimum - sqrt(8/27)) < 1e-6
@test norm(r.minimizer - [1/3, 8/27]) < 1e-6
end
Expand All @@ -46,7 +51,12 @@ end

@testset "MMA $(alg isa MMA87 ? "1987" : "2002")" for alg in (MMA87(), MMA02())
for convcriteria in (KKTCriteria(), IpoptCriteria())
r = NonconvexMMA.optimize(m, alg, [1.234, 2.345], options = options, convcriteria = convcriteria)
options = MMAOptions(;
tol = Tolerance(kkt = 1e-6, f = 0.0),
s_init = 0.1,
convcriteria,
)
r = NonconvexMMA.optimize(m, alg, [1.234, 2.345], options = options)
@test abs(r.minimum - sqrt(8/27)) < 1e-6
@test norm(r.minimizer - [1/3, 8/27]) < 1e-6
end
Expand All @@ -60,7 +70,12 @@ end

@testset "MMA $(alg isa MMA87 ? "1987" : "2002")" for alg in (MMA87(), MMA02())
for convcriteria in (KKTCriteria(), IpoptCriteria())
r = NonconvexMMA.optimize(m, alg, [1.234, 2.345], options = options, convcriteria = convcriteria)
options = MMAOptions(;
tol = Tolerance(kkt = 1e-6, f = 0.0),
s_init = 0.1,
convcriteria,
)
r = NonconvexMMA.optimize(m, alg, [1.234, 2.345], options = options)
@test abs(r.minimum - sqrt(8/27)) < 1e-6
@test norm(r.minimizer - [1/3, 8/27]) < 1e-6
end
Expand All @@ -74,7 +89,12 @@ end

@testset "MMA $(alg isa MMA87 ? "1987" : "2002")" for alg in (MMA87(), MMA02())
for convcriteria in (KKTCriteria(), IpoptCriteria())
r = NonconvexMMA.optimize(m, alg, [1.234, 2.345], options = options, convcriteria = convcriteria)
options = MMAOptions(;
tol = Tolerance(kkt = 1e-6, f = 0.0),
s_init = 0.1,
convcriteria,
)
r = NonconvexMMA.optimize(m, alg, [1.234, 2.345], options = options)
@test abs(r.minimum - sqrt(8/27)) < 1e-6
@test norm(r.minimizer - [1/3, 8/27]) < 1e-6
end
Expand All @@ -92,7 +112,12 @@ end
add_ineq_constraint!(m, x -> g(x, -1, 1))

for convcriteria in (KKTCriteria(), IpoptCriteria())
r = NonconvexMMA.optimize(m, MMA02(), [0.4, 0.5], options = options, convcriteria = convcriteria)
options = MMAOptions(;
tol = Tolerance(kkt = 1e-6, f = 0.0),
s_init = 0.1,
convcriteria,
)
r = NonconvexMMA.optimize(m, MMA02(), [0.4, 0.5], options = options)
@test abs(r.minimum - sqrt(8/27)) < 1e-6
@test norm(r.minimizer - [1/3, 8/27]) < 1e-6
end
Expand Down

0 comments on commit 9238530

Please sign in to comment.