Skip to content

Commit

Permalink
implementing an early stopper for cpsat solver, based on optimality g…
Browse files Browse the repository at this point in the history
…ap to best bound. this could be generalized to mip solver in further work
  • Loading branch information
g-poveda committed Jan 23, 2025
1 parent 98d31ae commit 65fb247
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 1 deletion.
35 changes: 35 additions & 0 deletions discrete_optimization/generic_tools/callbacks/early_stoppers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from discrete_optimization.generic_tools.callbacks.callback import Callback
from discrete_optimization.generic_tools.do_solver import SolverDO
from discrete_optimization.generic_tools.ortools_cpsat_tools import OrtoolsCpSatSolver
from discrete_optimization.generic_tools.result_storage.result_storage import (
ResultStorage,
)
Expand Down Expand Up @@ -69,3 +70,37 @@ def on_step_end(
return True
else:
return False


class ObjectiveGapCpSatSolver(Callback):
"""
Stop the cpsat solver according to some classical convergence criteria
It could be done differently (playing with parameters of cpsat directly)
"""

def __init__(
self,
objective_gap_rel: Optional[float] = None,
objective_gap_abs: Optional[float] = None,
):
self.objective_gap_rel = objective_gap_rel
self.objective_gap_abs = objective_gap_abs

def on_step_end(
self, step: int, res: ResultStorage, solver: OrtoolsCpSatSolver
) -> Optional[bool]:
best_sol = solver.clb.ObjectiveValue()
bound = solver.clb.BestObjectiveBound()
if self.objective_gap_abs is not None:
if abs(bound - best_sol) <= self.objective_gap_abs:
logger.debug(
f"Stopping search, absolute gap {abs(bound-best_sol)}<{self.objective_gap_abs}"
)
return True
if self.objective_gap_rel is not None:
if bound != 0:
if abs(bound - best_sol) / abs(bound) <= self.objective_gap_rel:
logger.debug(
f"Stopping search, relative gap {abs(bound-best_sol)/abs(bound)}<{self.objective_gap_rel}"
)
return True
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def on_step_end(
}
)
if solver.clb.ObjectiveValue() == solver.clb.BestObjectiveBound():
return False
return True

def on_solve_start(self, solver: OrtoolsCpSatSolver):
self.starting_time = perf_counter()
Expand Down
40 changes: 40 additions & 0 deletions tests/generic_tools/callbacks/test_earlystopobjective_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2025 AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import pytest

from discrete_optimization.generic_tools.callbacks.early_stoppers import (
ObjectiveGapCpSatSolver,
)
from discrete_optimization.generic_tools.cp_tools import ParametersCp
from discrete_optimization.knapsack.parser import get_data_available, parse_file
from discrete_optimization.knapsack.problem import KnapsackProblem, MobjKnapsackModel
from discrete_optimization.knapsack.solvers.cpsat import CpSatKnapsackSolver


def test_knapsack_ortools_objective_callback():
model_file = [f for f in get_data_available() if "ks_300_0" in f][0]
model: KnapsackProblem = parse_file(model_file, force_recompute_values=True)
model: MobjKnapsackModel = MobjKnapsackModel.from_knapsack(model)
solver = CpSatKnapsackSolver(model)
solver.init_model()
objective_gap_rel = 0.1
objective_gap_abs = 10
mycb = ObjectiveGapCpSatSolver(
objective_gap_rel=objective_gap_rel, objective_gap_abs=objective_gap_abs
)
parameters_cp = ParametersCp.default()
result_storage = solver.solve(
time_limit=10,
parameters_cp=parameters_cp,
callbacks=[mycb],
ortools_cpsat_solver_kwargs={"log_search_progress": True},
)
assert (
abs(solver.clb.ObjectiveValue() - solver.clb.BestObjectiveBound())
<= objective_gap_abs
or abs(solver.clb.ObjectiveValue() - solver.clb.BestObjectiveBound())
/ abs(solver.clb.BestObjectiveBound())
<= objective_gap_rel
)

0 comments on commit 65fb247

Please sign in to comment.