forked from airbus/discrete-optimization
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding CPSAT native model for coloring problem
- added example and unit test associated.
- Loading branch information
Showing
4 changed files
with
260 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
157 changes: 157 additions & 0 deletions
157
discrete_optimization/coloring/solvers/coloring_cpsat_solver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
# Copyright (c) 2024 AIRBUS and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
from enum import Enum | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from ortools.sat.python.cp_model import CpModel, CpSolverSolutionCallback, IntVar | ||
|
||
from discrete_optimization.coloring.coloring_model import ColoringSolution | ||
from discrete_optimization.coloring.solvers.coloring_solver_with_starting_solution import ( | ||
SolverColoringWithStartingSolution, | ||
) | ||
from discrete_optimization.generic_tools.do_problem import ( | ||
ParamsObjectiveFunction, | ||
Problem, | ||
Solution, | ||
) | ||
from discrete_optimization.generic_tools.ortools_cpsat_tools import OrtoolsCPSatSolver | ||
|
||
|
||
class ModelingCPSat(Enum): | ||
BINARY = 0 | ||
INTEGER = 1 | ||
|
||
|
||
class ColoringCPSatSolver(OrtoolsCPSatSolver, SolverColoringWithStartingSolution): | ||
def __init__( | ||
self, | ||
problem: Problem, | ||
params_objective_function: Optional[ParamsObjectiveFunction] = None, | ||
**kwargs: Any, | ||
): | ||
super().__init__(problem, params_objective_function, **kwargs) | ||
self.modeling: Optional[ModelingCPSat] = None | ||
self.variables: Dict[str, Union[List[IntVar], List[Dict[int, IntVar]]]] = {} | ||
|
||
def retrieve_solution(self, cpsolvercb: CpSolverSolutionCallback) -> Solution: | ||
if self.modeling == ModelingCPSat.INTEGER: | ||
return ColoringSolution( | ||
problem=self.problem, | ||
colors=[ | ||
cpsolvercb.Value(self.variables["colors"][i]) | ||
for i in range(len(self.variables["colors"])) | ||
], | ||
) | ||
if self.modeling == ModelingCPSat.BINARY: | ||
colors = [None for i in range(len(self.variables["colors"]))] | ||
for i in range(len(self.variables["colors"])): | ||
c = next( | ||
( | ||
j | ||
for j in self.variables["colors"][i] | ||
if cpsolvercb.Value(self.variables["colors"][i][j]) == 1 | ||
), | ||
None, | ||
) | ||
colors[i] = c | ||
return ColoringSolution(problem=self.problem, colors=colors) | ||
|
||
def init_model_binary(self, nb_colors: int, **kwargs): | ||
cp_model = CpModel() | ||
allocation_binary = [ | ||
{ | ||
j: cp_model.NewBoolVar(name=f"allocation_{i}_{j}") | ||
for j in range(nb_colors) | ||
} | ||
for i in range(self.problem.number_of_nodes) | ||
] | ||
for i in range(len(allocation_binary)): | ||
cp_model.AddExactlyOne( | ||
[allocation_binary[i][j] for j in allocation_binary[i]] | ||
) | ||
if self.problem.has_constraints_coloring: | ||
for node in self.problem.constraints_coloring.color_constraint: | ||
ind = self.problem.index_nodes_name[node] | ||
col = self.problem.constraints_coloring.color_constraint[node] | ||
cp_model.Add(allocation_binary[ind][col] == 1) | ||
for c in allocation_binary[ind]: | ||
if c != col: | ||
# Could do it more efficiently | ||
cp_model.Add(allocation_binary[ind][c] == 0) | ||
for edge in self.problem.graph.edges: | ||
ind1 = self.problem.index_nodes_name[edge[0]] | ||
ind2 = self.problem.index_nodes_name[edge[1]] | ||
for team in allocation_binary[ind1]: | ||
if team in allocation_binary[ind2]: | ||
cp_model.AddForbiddenAssignments( | ||
[allocation_binary[ind1][team], allocation_binary[ind2][team]], | ||
[(1, 1)], | ||
) | ||
used = [cp_model.NewBoolVar(f"used_{j}") for j in range(nb_colors)] | ||
if self.problem.use_subset: | ||
indexes_subset = self.problem.index_subset_nodes | ||
else: | ||
indexes_subset = range(len(allocation_binary)) | ||
for i in indexes_subset: | ||
for j in allocation_binary[i]: | ||
cp_model.Add(used[j] >= allocation_binary[i][j]) | ||
cp_model.Minimize(sum(used)) | ||
self.cp_model = cp_model | ||
self.variables["colors"] = allocation_binary | ||
self.variables["used"] = used | ||
|
||
def init_model_integer(self, nb_colors: int, **kwargs): | ||
cp_model = CpModel() | ||
variables = [ | ||
cp_model.NewIntVar(0, nb_colors - 1, name=f"c_{i}") | ||
for i in range(self.problem.number_of_nodes) | ||
] | ||
for edge in self.problem.graph.edges: | ||
ind_0 = self.problem.index_nodes_name[edge[0]] | ||
ind_1 = self.problem.index_nodes_name[edge[1]] | ||
cp_model.Add(variables[ind_0] != variables[ind_1]) | ||
if self.problem.has_constraints_coloring: | ||
for node in self.problem.constraints_coloring.color_constraint: | ||
ind = self.problem.index_nodes_name[node] | ||
cp_model.Add( | ||
variables[ind] | ||
== self.problem.constraints_coloring.color_constraint[node] | ||
) | ||
used = [cp_model.NewBoolVar(name=f"used_{c}") for c in range(nb_colors)] | ||
|
||
def add_indicator(vars, value, presence_value, model): | ||
bool_vars = [] | ||
for var in vars: | ||
bool_var = model.NewBoolVar("") | ||
model.Add(var == value).OnlyEnforceIf(bool_var) | ||
model.Add(var != value).OnlyEnforceIf(bool_var.Not()) | ||
bool_vars.append(bool_var) | ||
model.AddMaxEquality(presence_value, bool_vars) | ||
|
||
for j in range(nb_colors): | ||
if self.problem.use_subset: | ||
indexes = self.problem.index_subset_nodes | ||
vars = [variables[i] for i in indexes] | ||
else: | ||
vars = variables | ||
add_indicator(vars, j, used[j], cp_model) | ||
cp_model.Minimize(sum(used)) | ||
self.cp_model = cp_model | ||
self.variables["colors"] = variables | ||
self.variables["used"] = used | ||
|
||
def init_model(self, **args: Any) -> None: | ||
modeling = args.get("modeling", ModelingCPSat.INTEGER) | ||
assert isinstance(modeling, ModelingCPSat) | ||
if "nb_colors" not in args: | ||
solution = self.get_starting_solution(**args) | ||
nb_colors = self.problem.count_colors_all_index(solution.colors) | ||
args["nb_colors"] = nb_colors | ||
else: | ||
nb_colors = args["nb_colors"] | ||
if modeling == ModelingCPSat.BINARY: | ||
self.init_model_binary(**args) | ||
if modeling == ModelingCPSat.INTEGER: | ||
self.init_model_integer(**args) | ||
self.modeling = modeling |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# Copyright (c) 2024 AIRBUS and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import logging | ||
import os | ||
|
||
from discrete_optimization.coloring.coloring_model import ( | ||
ConstraintsColoring, | ||
transform_coloring_problem, | ||
) | ||
|
||
os.environ["DO_SKIP_MZN_CHECK"] = "1" | ||
|
||
import logging | ||
|
||
from discrete_optimization.coloring.coloring_parser import ( | ||
get_data_available, | ||
parse_file, | ||
) | ||
from discrete_optimization.coloring.coloring_plot import plot_coloring_solution, plt | ||
from discrete_optimization.coloring.solvers.coloring_cpsat_solver import ( | ||
ColoringCPSatSolver, | ||
ModelingCPSat, | ||
) | ||
from discrete_optimization.generic_tools.callbacks.loggers import NbIterationTracker | ||
from discrete_optimization.generic_tools.cp_tools import ParametersCP | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
|
||
def run_cpsat_coloring(): | ||
logging.basicConfig(level=logging.INFO) | ||
file = [f for f in get_data_available() if "gc_70_5" in f][0] | ||
color_problem = parse_file(file) | ||
solver = ColoringCPSatSolver(color_problem, params_objective_function=None) | ||
solver.init_model(nb_colors=20, modeling=ModelingCPSat.BINARY) | ||
p = ParametersCP.default() | ||
p.time_limit = 20 | ||
result_store = solver.solve( | ||
callbacks=[NbIterationTracker(step_verbosity_level=logging.INFO)], | ||
parameters_cp=p, | ||
) | ||
print("Status solver : ", solver.get_status_solver()) | ||
solution, fit = result_store.get_best_solution_fit() | ||
plot_coloring_solution(solution) | ||
plt.show() | ||
print(solution, fit) | ||
print("Evaluation : ", color_problem.evaluate(solution)) | ||
print("Satisfy : ", color_problem.satisfy(solution)) | ||
|
||
|
||
def run_cpsat_coloring_with_constraints(): | ||
logging.basicConfig(level=logging.INFO) | ||
file = [f for f in get_data_available() if "gc_20_1" in f][0] | ||
color_problem = parse_file(file) | ||
color_problem = transform_coloring_problem( | ||
color_problem, | ||
subset_nodes=set(range(10)), | ||
constraints_coloring=ConstraintsColoring(color_constraint={0: 0, 1: 1, 2: 2}), | ||
) | ||
solver = ColoringCPSatSolver(color_problem) | ||
solver.init_model(nb_colors=20) | ||
p = ParametersCP.default() | ||
p.time_limit = 20 | ||
result_store = solver.solve(parameters_cp=p) | ||
solution, fit = result_store.get_best_solution_fit() | ||
print("Status solver : ", solver.get_status_solver()) | ||
plot_coloring_solution(solution) | ||
plt.show() | ||
print(solution, fit) | ||
print("Evaluation : ", color_problem.evaluate(solution)) | ||
print("Satisfy : ", color_problem.satisfy(solution)) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_cpsat_coloring() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters