diff --git a/ortools/linear_solver/xpress_interface.cc b/ortools/linear_solver/xpress_interface.cc index 2fc90393b29..0282e5cf574 100644 --- a/ortools/linear_solver/xpress_interface.cc +++ b/ortools/linear_solver/xpress_interface.cc @@ -227,7 +227,7 @@ class XpressMPCallbackContext : public MPCallbackContext { : xprsprob_(xprsprob), event_(event), num_nodes_(num_nodes), - variable_values_(0) {}; + variable_values_(0){}; // Implementation of the interface. MPCallbackEvent Event() override { return event_; }; @@ -260,7 +260,7 @@ class XpressMPCallbackContext : public MPCallbackContext { // Wraps the MPCallback in order to catch and store exceptions class MPCallbackWrapper { public: - explicit MPCallbackWrapper(MPCallback* callback) : callback_(callback) {}; + explicit MPCallbackWrapper(MPCallback* callback) : callback_(callback){}; MPCallback* GetCallback() const { return callback_; } // Since our (C++) call-back functions are called from the XPRESS (C) code, // exceptions thrown in our call-back code are not caught by XPRESS. @@ -328,6 +328,7 @@ class XpressInterface : public MPSolverInterface { void SetConstraintBounds(int row_index, double lb, double ub) override; void AddRowConstraint(MPConstraint* ct) override; + bool AddIndicatorConstraint(MPConstraint* ct) override; void AddVariable(MPVariable* var) override; void SetCoefficient(MPConstraint* constraint, MPVariable const* variable, double new_value, double old_value) override; @@ -1541,6 +1542,10 @@ void XpressInterface::ExtractNewConstraints() { unique_ptr sense(new char[chunk]); unique_ptr rhs(new double[chunk]); unique_ptr rngval(new double[chunk]); + int n_indicators = 0; + std::vector indicator_rowind; + std::vector indicator_colind; + std::vector indicator_complement; // Loop over the new constraints, collecting rows for up to // CHUNK constraints into the arrays so that adding constraints @@ -1576,6 +1581,14 @@ void XpressInterface::ExtractNewConstraints() { ++nextNz; } } + + // Detect & store indicator constraints + if (ct->indicator_variable() != nullptr) { + n_indicators++; + indicator_rowind.push_back(nextRow); + indicator_colind.push_back(ct->indicator_variable()->index()); + indicator_complement.push_back(ct->indicator_value() ? 1 : -1); + } } if (nextRow > 0) { CHECK_STATUS(XPRSaddrows(mLp, nextRow, nextNz, sense.get(), rhs.get(), @@ -1583,6 +1596,11 @@ void XpressInterface::ExtractNewConstraints() { rmatval.get())); } } + + // Set indicator constraints in XPRESS + CHECK_STATUS(XPRSsetindicators(mLp, n_indicators, indicator_rowind.data(), + indicator_colind.data(), + indicator_complement.data())); } catch (...) { // Undo all changes in case of error. int const rows = getnumrows(mLp); @@ -2288,4 +2306,9 @@ double XpressMPCallbackContext::SuggestSolution( return NAN; } +bool XpressInterface::AddIndicatorConstraint(MPConstraint* ct) { + InvalidateModelSynchronization(); + return !IsContinuous(); +} + } // namespace operations_research diff --git a/ortools/linear_solver/xpress_interface_test.cc b/ortools/linear_solver/xpress_interface_test.cc index 80916e72fe1..40ad792c33e 100644 --- a/ortools/linear_solver/xpress_interface_test.cc +++ b/ortools/linear_solver/xpress_interface_test.cc @@ -285,9 +285,9 @@ class MyMPCallback : public MPCallback { MyMPCallback(MPSolver* mpSolver, bool should_throw) : MPCallback(false, false), mpSolver_(mpSolver), - should_throw_(should_throw) {}; + should_throw_(should_throw){}; - ~MyMPCallback() override {}; + ~MyMPCallback() override{}; void RunCallback(MPCallbackContext* callback_context) override { if (should_throw_) { @@ -1411,6 +1411,59 @@ TEST_F(XpressFixtureMIP, CallbackThrowsException) { ASSERT_NE(errors.find(expected_error), std::string::npos); } +TEST_F(XpressFixtureMIP, IndicatorConstraint0) { + solver.EnableOutput(); + // Maximize x <= 100 + auto x = solver.MakeNumVar(0, 100, "x"); + solver.MutableObjective()->SetMaximization(); + solver.MutableObjective()->SetCoefficient(x, 1); + // With indicator constraint + // if var = 0, then x <= 10 + auto var = solver.MakeBoolVar("indicator_var"); + auto ct = solver.MakeIndicatorConstraint(0, 10, "test", var, false); + ct->SetCoefficient(x, 1); + + // Leave var free ==> x = 100 + solver.Solve(); + EXPECT_EQ(var->solution_value(), 1); + EXPECT_EQ(x->solution_value(), 100); + + // Force var to 0 ==> x = 10 + // WARNING : can't use var->SetUB(0), because then XPRESS would automatically + // change its type to continuous, then fail on indicator variable evaluation. + // We have to add a constraint instead. + ct = solver.MakeRowConstraint(0, 0, "set_indicator_var_to_0"); + ct->SetCoefficient(var, 1); + solver.Solve(); + EXPECT_EQ(x->solution_value(), 10); +} + +TEST_F(XpressFixtureMIP, IndicatorConstraint1) { + // Maximize x <= 100 + auto x = solver.MakeNumVar(0, 100, "x"); + solver.MutableObjective()->SetMaximization(); + solver.MutableObjective()->SetCoefficient(x, 1); + // With indicator constraint + // if var = 1, then x <= 10 + auto var = solver.MakeBoolVar("indicator_var"); + auto ct = solver.MakeIndicatorConstraint(0, 10, "test", var, true); + ct->SetCoefficient(x, 1); + + // Leave var free ==> x = 100 + solver.Solve(); + EXPECT_EQ(var->solution_value(), 0); + EXPECT_EQ(x->solution_value(), 100); + + // Force var to 0 ==> x = 10 + // WARNING : can't use var->SetLB(1), because then XPRESS would automatically + // change its type to continuous, then fail on indicator variable evaluation. + // We have to add a constraint instead. + ct = solver.MakeRowConstraint(1, 1, "set_indicator_var_to_1"); + ct->SetCoefficient(var, 1); + solver.Solve(); + EXPECT_EQ(x->solution_value(), 10); +} + } // namespace operations_research int main(int argc, char** argv) { diff --git a/ortools/xpress/environment.cc b/ortools/xpress/environment.cc index 13b9fccb4da..c4e8247fee2 100644 --- a/ortools/xpress/environment.cc +++ b/ortools/xpress/environment.cc @@ -73,6 +73,7 @@ std::function XPRSdelrows = nullptr; std::function XPRSaddcols = nullptr; std::function XPRSaddnames = nullptr; +std::function XPRSsetindicators = nullptr; std::function XPRSgetnames = nullptr; std::function XPRSdelcols = nullptr; std::function XPRSchgcoltype = nullptr; @@ -137,6 +138,7 @@ void LoadXpressFunctions(DynamicLibrary* xpress_dynamic_library) { xpress_dynamic_library->GetFunction(&XPRSdelrows, "XPRSdelrows"); xpress_dynamic_library->GetFunction(&XPRSaddcols, "XPRSaddcols"); xpress_dynamic_library->GetFunction(&XPRSaddnames, "XPRSaddnames"); + xpress_dynamic_library->GetFunction(&XPRSsetindicators, "XPRSsetindicators"); xpress_dynamic_library->GetFunction(&XPRSgetnames, "XPRSgetnames"); xpress_dynamic_library->GetFunction(&XPRSdelcols, "XPRSdelcols"); xpress_dynamic_library->GetFunction(&XPRSchgcoltype, "XPRSchgcoltype"); diff --git a/ortools/xpress/environment.h b/ortools/xpress/environment.h index 8bcb3a10767..f585985b188 100644 --- a/ortools/xpress/environment.h +++ b/ortools/xpress/environment.h @@ -469,6 +469,7 @@ extern std::function XPRSdelrows; extern std::function XPRSaddcols; extern std::function XPRSaddnames; +extern std::function XPRSsetindicators; extern std::function XPRSgetnames; extern std::function XPRSdelcols; extern std::function XPRSchgcoltype;