From 65fb2476b0e5924a0e09dbb13b419d8c6eb01350 Mon Sep 17 00:00:00 2001 From: poveda_g Date: Thu, 23 Jan 2025 16:45:48 +0100 Subject: [PATCH] implementing an early stopper for cpsat solver, based on optimality gap to best bound. this could be generalized to mip solver in further work --- .../generic_tools/callbacks/early_stoppers.py | 35 ++++++++++++++++ .../callbacks/stats_retrievers.py | 2 +- .../test_earlystopobjective_callback.py | 40 +++++++++++++++++++ 3 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 tests/generic_tools/callbacks/test_earlystopobjective_callback.py diff --git a/discrete_optimization/generic_tools/callbacks/early_stoppers.py b/discrete_optimization/generic_tools/callbacks/early_stoppers.py index a5d429e0a..ab9550896 100644 --- a/discrete_optimization/generic_tools/callbacks/early_stoppers.py +++ b/discrete_optimization/generic_tools/callbacks/early_stoppers.py @@ -8,6 +8,7 @@ from discrete_optimization.generic_tools.callbacks.callback import Callback from discrete_optimization.generic_tools.do_solver import SolverDO +from discrete_optimization.generic_tools.ortools_cpsat_tools import OrtoolsCpSatSolver from discrete_optimization.generic_tools.result_storage.result_storage import ( ResultStorage, ) @@ -69,3 +70,37 @@ def on_step_end( return True else: return False + + +class ObjectiveGapCpSatSolver(Callback): + """ + Stop the cpsat solver according to some classical convergence criteria + It could be done differently (playing with parameters of cpsat directly) + """ + + def __init__( + self, + objective_gap_rel: Optional[float] = None, + objective_gap_abs: Optional[float] = None, + ): + self.objective_gap_rel = objective_gap_rel + self.objective_gap_abs = objective_gap_abs + + def on_step_end( + self, step: int, res: ResultStorage, solver: OrtoolsCpSatSolver + ) -> Optional[bool]: + best_sol = solver.clb.ObjectiveValue() + bound = solver.clb.BestObjectiveBound() + if self.objective_gap_abs is not None: + if abs(bound - best_sol) <= self.objective_gap_abs: + logger.debug( + f"Stopping search, absolute gap {abs(bound-best_sol)}<{self.objective_gap_abs}" + ) + return True + if self.objective_gap_rel is not None: + if bound != 0: + if abs(bound - best_sol) / abs(bound) <= self.objective_gap_rel: + logger.debug( + f"Stopping search, relative gap {abs(bound-best_sol)/abs(bound)}<{self.objective_gap_rel}" + ) + return True diff --git a/discrete_optimization/generic_tools/callbacks/stats_retrievers.py b/discrete_optimization/generic_tools/callbacks/stats_retrievers.py index 7a15f911a..b3f9ea007 100644 --- a/discrete_optimization/generic_tools/callbacks/stats_retrievers.py +++ b/discrete_optimization/generic_tools/callbacks/stats_retrievers.py @@ -66,7 +66,7 @@ def on_step_end( } ) if solver.clb.ObjectiveValue() == solver.clb.BestObjectiveBound(): - return False + return True def on_solve_start(self, solver: OrtoolsCpSatSolver): self.starting_time = perf_counter() diff --git a/tests/generic_tools/callbacks/test_earlystopobjective_callback.py b/tests/generic_tools/callbacks/test_earlystopobjective_callback.py new file mode 100644 index 000000000..f466e0d83 --- /dev/null +++ b/tests/generic_tools/callbacks/test_earlystopobjective_callback.py @@ -0,0 +1,40 @@ +# Copyright (c) 2025 AIRBUS and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +from discrete_optimization.generic_tools.callbacks.early_stoppers import ( + ObjectiveGapCpSatSolver, +) +from discrete_optimization.generic_tools.cp_tools import ParametersCp +from discrete_optimization.knapsack.parser import get_data_available, parse_file +from discrete_optimization.knapsack.problem import KnapsackProblem, MobjKnapsackModel +from discrete_optimization.knapsack.solvers.cpsat import CpSatKnapsackSolver + + +def test_knapsack_ortools_objective_callback(): + model_file = [f for f in get_data_available() if "ks_300_0" in f][0] + model: KnapsackProblem = parse_file(model_file, force_recompute_values=True) + model: MobjKnapsackModel = MobjKnapsackModel.from_knapsack(model) + solver = CpSatKnapsackSolver(model) + solver.init_model() + objective_gap_rel = 0.1 + objective_gap_abs = 10 + mycb = ObjectiveGapCpSatSolver( + objective_gap_rel=objective_gap_rel, objective_gap_abs=objective_gap_abs + ) + parameters_cp = ParametersCp.default() + result_storage = solver.solve( + time_limit=10, + parameters_cp=parameters_cp, + callbacks=[mycb], + ortools_cpsat_solver_kwargs={"log_search_progress": True}, + ) + assert ( + abs(solver.clb.ObjectiveValue() - solver.clb.BestObjectiveBound()) + <= objective_gap_abs + or abs(solver.clb.ObjectiveValue() - solver.clb.BestObjectiveBound()) + / abs(solver.clb.BestObjectiveBound()) + <= objective_gap_rel + )