-
Notifications
You must be signed in to change notification settings - Fork 201
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
Add linear program solver based on the restarted Halpern primal-dual hybrid gradient (rHPDHG) algorithm.
1 parent
3d8c391
commit 8638ca6
Showing
12 changed files
with
591 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -129,6 +129,7 @@ disable=R, | |
wrong-import-order, | ||
xrange-builtin, | ||
zip-builtin-not-iterating, | ||
invalid-name, | ||
|
||
|
||
[REPORTS] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
Linear programming | ||
================== | ||
|
||
.. currentmodule:: optax.linprog | ||
|
||
.. autosummary:: | ||
rhpdhg | ||
|
||
|
||
Restarted Halpern primal-dual hybrid gradient method | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
.. autofunction:: rhpdhg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""The linear programming sub-package.""" | ||
|
||
# pylint:disable=g-importing-member | ||
|
||
from optax.linprog._rhpdhg import solve_general as rhpdhg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""The restarted Halpern primal-dual hybrid gradient method.""" | ||
|
||
from jax import lax, numpy as jnp | ||
from optax import tree_utils as otu | ||
|
||
|
||
def solve_canonical(
    c, A, b, iters, reflect=True, restarts=True, tau=None, sigma=None
):
  r"""Runs restarted Halpern PDHG (RHPDHG) on a canonical-form linear program.

  Minimizes :math:`c \cdot x` subject to :math:`A x = b` and :math:`x \geq 0`.
  Both the primal and the dual iterate are initialized to zero.

  See also `MPAX <https://github.com/MIT-Lu-Lab/MPAX>`_.

  Args:
    c: Cost vector.
    A: Equality constraint matrix.
    b: Equality constraint vector.
    iters: Number of iterations to run the solver for.
    reflect: Use reflection. See paper for details.
    restarts: Use restarts. See paper for details.
    tau: Primal step size. If ``None``, ``1 / (2 * norm(A, ord=2))`` is used.
      See paper for details.
    sigma: Dual step size. If ``None``, ``1 / (2 * norm(A, ord=2))`` is used.
      See paper for details.

  Returns:
    A dictionary whose entries are as follows:

    - primal: The final primal solution.
    - dual: The final dual solution.
    - primal_iterates: The primal iterates, stacked along the leading axis.
      Entry ``i`` is the iterate held *before* the ``i``-th update.
    - dual_iterates: The dual iterates, recorded in the same way.

  Examples:
    The example below calls the general-form entry point
    ``optax.linprog.rhpdhg``, which reduces the problem to canonical form and
    then runs this solver.

    >>> from jax import numpy as jnp
    >>> import optax
    >>> c = -jnp.array([2, 1])
    >>> A = jnp.zeros([0, 2])
    >>> b = jnp.zeros(0)
    >>> G = jnp.array([[3, 1], [1, 1], [1, 4]])
    >>> h = jnp.array([21, 9, 24])
    >>> x = optax.linprog.rhpdhg(c, A, b, G, h, 1_000_000)['primal']
    >>> print(x[0])
    5.99...
    >>> print(x[1])
    2.99...

  References:
    Haihao Lu, Jinwen Yang, `Restarted Halpern PDHG for Linear Programming
    <https://arxiv.org/abs/2407.16144>`_, 2024
  """
  # Any step size the caller left unset defaults to 1 / (2 * ||A||_2).
  if tau is None or sigma is None:
    default_step = 1 / (2 * jnp.linalg.norm(A, axis=(0, 1), ord=2))
    tau = default_step if tau is None else tau
    sigma = default_step if sigma is None else sigma

  def pdhg_step(point):
    # One primal-dual hybrid gradient (PDHG) step on the (primal, dual) pair.
    primal, dual = point
    next_primal = (primal + tau * (dual @ A - c)).clip(min=0)
    next_dual = dual + sigma * (b - A @ (2 * next_primal - primal))
    return next_primal, next_dual

  def halpern_step(point, k, anchor):
    # Halpern PDHG: mix the (optionally reflected) PDHG step with the anchor.
    stepped = pdhg_step(point)
    if reflect:
      target = otu.tree_sub(otu.tree_scalar_mul(2, stepped), point)
    else:
      target = stepped
    denom = k + 2
    averaged = otu.tree_add(
        otu.tree_scalar_mul((k + 1) / denom, target),
        otu.tree_scalar_mul(1 / denom, anchor),
    )
    return averaged, stepped

  def scan_body(carry, _):
    point, k, anchor, anchor_gap = carry
    next_point, stepped = halpern_step(point, k, anchor)

    if not restarts:
      return (next_point, k + 1, anchor, anchor_gap), point

    # Squared fixed-point residual of the current iterate. A restart is
    # triggered once it has shrunk by a factor exp(-2) relative to the value
    # stored for the current anchor.
    gap = otu.tree_l2_norm(otu.tree_sub(point, stepped), squared=True)
    should_restart = gap <= anchor_gap * jnp.exp(-2)
    next_carry = otu.tree_where(
        should_restart,
        (next_point, 0, next_point, gap),
        (next_point, k + 1, anchor, anchor_gap),
    )
    return next_carry, point

  num_constraints, num_vars = A.shape
  start = jnp.zeros(num_vars), jnp.zeros(num_constraints)
  start_gap = otu.tree_l2_norm(
      otu.tree_sub(start, pdhg_step(start)), squared=True
  )
  (final, _, _, _), history = lax.scan(
      scan_body, (start, 0, start, start_gap), length=iters
  )
  primal, dual = final
  primal_history, dual_history = history
  return {
      "primal": primal,
      "dual": dual,
      "primal_iterates": primal_history,
      "dual_iterates": dual_history,
  }
|
||
|
||
def general_to_canonical(c, A, b, G, h):
  """Rewrites a general-form linear program in canonical form.

  The general-form program

    Minimize c · x subject to
      A x = b
      G x ≤ h

  becomes the canonical-form program

    Minimize c · (x⁺ - x⁻) subject to
      A (x⁺ - x⁻) = b
      G (x⁺ - x⁻) + s = h
      x⁺, x⁻, s ≥ 0

  whose variable vector is the concatenation of

  - the positive part of x,
  - the negative part of x,
  - the slacks s (one per inequality constraint).

  Args:
    c: Cost vector.
    A: Equality constraint matrix.
    b: Equality constraint vector.
    G: Inequality constraint matrix.
    h: Inequality constraint vector.

  Returns:
    A triple (c', A', b') representing the corresponding canonical form. The
    rows of A' (and entries of b') list the equality constraints first and the
    inequality constraints second.
  """
  num_eq, num_ineq = b.size, h.size

  # Slack variables carry zero cost.
  cost = jnp.concatenate([c, -c, jnp.zeros(num_ineq)])

  # Equality rows do not involve the slacks; each inequality row gets its own.
  eq_rows = jnp.concatenate(
      [A, -A, jnp.zeros([num_eq, num_ineq])], axis=1
  )
  ineq_rows = jnp.concatenate([G, -G, jnp.eye(num_ineq)], axis=1)

  matrix = jnp.concatenate([eq_rows, ineq_rows], axis=0)
  rhs = jnp.concatenate([b, h])
  return cost, matrix, rhs
|
||
|
||
def solve_general(
    c, A, b, G, h, iters, reflect=True, restarts=True, tau=None, sigma=None
):
  r"""Solves a general-form linear program with restarted Halpern PDHG.

  Minimizes :math:`c \cdot x` subject to :math:`A x = b` and
  :math:`G x \leq h`. The program is first rewritten in canonical form (free
  variables split into positive and negative parts, one slack per inequality)
  and then handed to the canonical-form RHPDHG solver.

  See also `MPAX <https://github.com/MIT-Lu-Lab/MPAX>`_.

  Args:
    c: Cost vector.
    A: Equality constraint matrix.
    b: Equality constraint vector.
    G: Inequality constraint matrix.
    h: Inequality constraint vector.
    iters: Number of iterations to run the solver for.
    reflect: Use reflection. See paper for details.
    restarts: Use restarts. See paper for details.
    tau: Primal step size. See paper for details.
    sigma: Dual step size. See paper for details.

  Returns:
    A dictionary whose entries are as follows:

    - primal: The final primal solution, i.e. the positive part minus the
      negative part of the canonical solution.
    - slacks: The final primal slack values.
    - canonical_result: The result for the canonical program that was used
      internally to find this solution. See paper for details.

  References:
    Haihao Lu, Jinwen Yang, `Restarted Halpern PDHG for Linear Programming
    <https://arxiv.org/abs/2407.16144>`_, 2024
  """
  c_can, A_can, b_can = general_to_canonical(c, A, b, G, h)
  canonical_result = solve_canonical(
      c_can,
      A_can,
      b_can,
      iters,
      reflect=reflect,
      restarts=restarts,
      tau=tau,
      sigma=sigma,
  )

  # The canonical primal vector is laid out as [x⁺, x⁻, slacks].
  num_vars = c.size
  canonical_primal = canonical_result["primal"]
  x_pos = canonical_primal[:num_vars]
  x_neg = canonical_primal[num_vars:2 * num_vars]
  slacks = canonical_primal[2 * num_vars:]

  return {
      "primal": x_pos - x_neg,
      "slacks": slacks,
      "canonical_result": canonical_result,
  }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for the restarted Halpern primal-dual hybrid gradient method.""" | ||
|
||
from functools import partial | ||
|
||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
import jax | ||
from jax import numpy as jnp | ||
import numpy as np | ||
import cvxpy as cp | ||
|
||
from optax.linprog import rhpdhg | ||
|
||
|
||
def solve_cvxpy(c, A, b, G, h):
  """Solves ``min c @ x`` s.t. ``A @ x == b``, ``G @ x <= h`` with CVXPY/GLPK.

  A constraint group is only added when its matrix has at least one row.

  Args:
    c: Cost vector.
    A: Equality constraint matrix.
    b: Equality constraint vector.
    G: Inequality constraint matrix.
    h: Inequality constraint vector.

  Returns:
    A pair ``(x, status)`` holding the value of the decision variable after
    solving and the problem status string reported by CVXPY.
  """
  x = cp.Variable(c.size)
  has_equalities = A.shape[0] > 0
  has_inequalities = G.shape[0] > 0
  # Constraint expressions are only built for non-empty groups.
  constraints = ([A @ x == b] if has_equalities else []) + (
      [G @ x <= h] if has_inequalities else []
  )
  problem = cp.Problem(cp.Minimize(c @ x), constraints)
  problem.solve(solver='GLPK')
  return x.value, problem.status
|
||
|
||
class RHPDHGTest(parameterized.TestCase):
  """Checks ``rhpdhg`` against a reference LP solver on random problems."""

  def setUp(self):
    super().setUp()
    self.f = jax.jit(partial(rhpdhg, iters=100_000))
    # Seeded generator so that a failing problem instance can be reproduced.
    self.rng = np.random.default_rng(0)

  @parameterized.parameters(
      dict(n_vars=n_vars, n_eq=n_eq, n_ineq=n_ineq)
      for n_vars in range(8)
      for n_eq in range(n_vars)
      for n_ineq in range(8)
      if n_eq + n_ineq >= n_vars
      # Make sure set of solvable LPs with these shapes is not null in measure.
  )
  def test_matches_reference_solver(self, n_vars, n_eq, n_ineq):
    """Compares the solver with the reference solution on a random LP.

    Args:
      n_vars: Number of decision variables.
      n_eq: Number of random equality constraints.
      n_ineq: Number of random inequality constraints (box constraints on
        every variable are added on top of these).
    """
    # For numerical testing purposes, constrain x to [-limit, limit].
    limit = 5
    # Upper bound on the random search so that the test cannot hang.
    max_attempts = 1000

    # Find a solvable LP.
    for _ in range(max_attempts):
      c = self.rng.normal(size=(n_vars,))
      A = self.rng.normal(size=(n_eq, n_vars))
      b = self.rng.normal(size=(n_eq,))
      G = self.rng.normal(size=(n_ineq, n_vars))
      h = self.rng.normal(size=(n_ineq,))

      # Box constraints: x <= limit and -x <= limit.
      G = jnp.concatenate([G, jnp.eye(n_vars), -jnp.eye(n_vars)])
      h = jnp.concatenate([h, jnp.full(n_vars * 2, limit)])

      r, status = solve_cvxpy(c, A, b, G, h)

      if status == 'optimal':
        break
    else:
      self.fail(f'No solvable LP found in {max_attempts} attempts.')

    result = self.f(c, A, b, G, h)
    x = result['primal']

    rtol = 1e-2
    atol = 1e-2

    with self.subTest('x approximately satisfies equality constraints'):
      np.testing.assert_allclose(A @ x, b, rtol=rtol, atol=atol)

    with self.subTest('x approximately satisfies inequality constraints'):
      # Clipping from below at h maps every satisfied row onto h itself, so
      # only rows with G @ x > h can produce a mismatch.
      np.testing.assert_allclose((G @ x).clip(min=h), h, rtol=rtol, atol=atol)

    with self.subTest('x is approximately as good as the reference solution'):
      cx = c @ x
      cr = c @ r
      # One-sided check: only an objective value worse (larger) than the
      # reference can produce a mismatch.
      np.testing.assert_allclose(cx.clip(min=cr), cr, rtol=rtol, atol=atol)
|
||
|
||
if __name__ == '__main__': | ||
absltest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters