Skip to content

Commit

Permalink
Change expected output in tests, handle multigraphs
Browse files Browse the repository at this point in the history
  • Loading branch information
torressa committed Jul 16, 2024
1 parent 9223210 commit 3e878d4
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 120 deletions.
19 changes: 12 additions & 7 deletions src/gurobi_optimods/min_cost_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,12 @@ def min_cost_flow_pandas(

source_label, target_label = arc_data.index.names

arc_data = (
arc_data.reset_index()
) # This is a workaround for duplicate entries being disallowed in gurobipy_pandas
multigraph = False
# This is a workaround for duplicate entries being disallowed in gurobipy_pandas
if arc_data.index.has_duplicates:
arc_data = arc_data.reset_index()
multigraph = True

arc_df = arc_data.gppd.add_vars(model, ub="capacity", obj="cost", name="flow")

balance_df = (
Expand All @@ -88,9 +91,9 @@ def min_cost_flow_pandas(
if model.Status in [GRB.INFEASIBLE, GRB.INF_OR_UNBD]:
raise ValueError("Unsatisfiable flows")

arc_df = arc_df.set_index(
["source", "target"]
) # Repair index that was reset above
if multigraph:
# Repair index that was reset above
arc_df = arc_df.set_index([source_label, target_label])
return model.ObjVal, arc_df["flow"].gppd.X


Expand Down Expand Up @@ -179,6 +182,8 @@ def min_cost_flow_networkx(G, *, create_env):
f"Solving min-cost flow with {len(G.nodes)} nodes and {len(G.edges)} edges"
)
with create_env() as env, gp.Model(env=env) as model:
multigraph = isinstance(G, nx.MultiGraph)

G = nx.MultiDiGraph(G)

edges, capacities, costs = gp.multidict(
Expand Down Expand Up @@ -221,7 +226,7 @@ def min_cost_flow_networkx(G, *, create_env):
raise ValueError("Unsatisfiable flows")

# Create a new Graph with selected edges in the matching
resulting_flow = nx.MultiDiGraph()
resulting_flow = nx.MultiDiGraph() if multigraph else nx.DiGraph()
resulting_flow.add_nodes_from(nodes)
resulting_flow.add_edges_from(
[(edge[0], edge[1], {"flow": v.X}) for edge, v in x.items() if v.X > 0.1]
Expand Down
40 changes: 14 additions & 26 deletions tests/test_graph_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import numpy as np


def _sort_key(x):
return str(x)


def check_solution_pandas(solution, candidates):
# Checks whether the solution (`pd.Series`) matches any of the list of
# candidates (containing `dict`)
if any(solution.to_dict() == c for c in candidates):
if any(
sorted(list(zip(solution.index.to_list(), solution.to_list())), key=_sort_key)
== sorted(c, key=_sort_key)
for c in candidates
):
return True
return False

Expand All @@ -21,30 +29,10 @@ def check_solution_scipy(solution, candidates):
def check_solution_networkx(solution, candidates):
# Checks whether the solution (`nx.DiGraph`) matches any of the list of
# candidates (containing tuples dict `{(i, j): data}`)
sol_dict = {(i, j): d for i, j, d in solution.edges(data=True)}
if any(sol_dict == c for c in candidates):
return True
return False


def check_solution_pandas_multi(solution, candidates):
# Checks whether the solution (`pd.Series`) matches any of the list of
# candidates (containing `pd.Series`)
if any(solution.reset_index().equals(c.reset_index()) for c in candidates):
sol_list = sorted(
[((i, j), data["flow"]) for i, j, data in solution.edges(data=True)],
key=_sort_key,
)
if any(sol_list == sorted(c, key=_sort_key) for c in candidates):
return True
return False


def check_solution_networkx_multi(solution, candidates):
# Checks whether the solution (`nx.DiGraph`) matches any of the list of
# candidates (containing tuples dict `{(i, j): data}`)
for candidate in candidates:

def edge_sort(row):
return (str(row[0]), str(row[1]))

if sorted(candidate, key=edge_sort) == sorted(
list(solution.edges(data=True)), key=edge_sort
):
return True
return False
180 changes: 93 additions & 87 deletions tests/test_min_cost_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@

from .test_graph_utils import (
check_solution_networkx,
check_solution_networkx_multi,
check_solution_pandas,
check_solution_pandas_multi,
check_solution_scipy,
)

Expand Down Expand Up @@ -93,7 +91,13 @@ def test_pandas(self):
cost, sol = mcf.min_cost_flow_pandas(edge_data, node_data)
sol = sol[sol > 0]
self.assertEqual(cost, 31)
candidate = {(0, 1): 1.0, (0, 2): 1.0, (1, 3): 1.0, (2, 4): 2.0, (4, 5): 2.0}
candidate = [
((0, 1), 1.0),
((0, 2), 1.0),
((1, 3), 1.0),
((2, 4), 2.0),
((4, 5), 2.0),
]
self.assertIsInstance(sol, pd.Series)
self.assertTrue(check_solution_pandas(sol, [candidate]))

Expand Down Expand Up @@ -125,13 +129,13 @@ def test_networkx(self):
G = datasets.simple_graph_networkx()
cost, sol = mcf.min_cost_flow_networkx(G)
self.assertEqual(cost, 31)
expected = {
(0, 1): {"flow": 1.0},
(0, 2): {"flow": 1.0},
(1, 3): {"flow": 1.0},
(2, 4): {"flow": 2.0},
(4, 5): {"flow": 2.0},
}
expected = [
((0, 1), 1.0),
((0, 2), 1.0),
((1, 3), 1.0),
((2, 4), 2.0),
((4, 5), 2.0),
]
self.assertIsInstance(sol, nx.Graph)
self.assertTrue(check_solution_networkx(sol, [expected]))

Expand All @@ -141,13 +145,13 @@ def test_networkx_renamed(self):
G = nx.relabel_nodes(G, {0: "s", 5: "t"})
cost, sol = mcf.min_cost_flow_networkx(G)
self.assertEqual(cost, 31)
expected = {
("s", 1): {"flow": 1.0},
("s", 2): {"flow": 1.0},
(1, 3): {"flow": 1.0},
(2, 4): {"flow": 2.0},
(4, "t"): {"flow": 2.0},
}
expected = [
(("s", 1), 1.0),
(("s", 2), 1.0),
((1, 3), 1.0),
((2, 4), 2.0),
((4, "t"), 2.0),
]
self.assertIsInstance(sol, nx.Graph)
self.assertTrue(check_solution_networkx(sol, [expected]))

Expand All @@ -158,24 +162,24 @@ def test_pandas(self):
cost, sol = mcf.min_cost_flow_pandas(edge_data, node_data)
sol = sol[sol > 0]
self.assertEqual(cost, 150)
candidate = {
(0, 1): 12.0,
(0, 2): 8.0,
(1, 3): 4.0,
(1, 2): 8.0,
(2, 3): 15.0,
(2, 4): 1.0,
(3, 4): 14.0,
}
candidate2 = {
(0, 1): 12.0,
(0, 2): 8.0,
(1, 3): 4.0,
(1, 2): 8.0,
(2, 3): 11.0,
(2, 4): 5.0,
(3, 4): 10.0,
}
candidate = [
((0, 1), 12.0),
((0, 2), 8.0),
((1, 3), 4.0),
((1, 2), 8.0),
((2, 3), 15.0),
((2, 4), 1.0),
((3, 4), 14.0),
]
candidate2 = [
((0, 1), 12.0),
((0, 2), 8.0),
((1, 3), 4.0),
((1, 2), 8.0),
((2, 3), 11.0),
((2, 4), 5.0),
((3, 4), 10.0),
]
self.assertTrue(check_solution_pandas(sol, [candidate, candidate2]))

def test_scipy(self):
Expand All @@ -197,24 +201,24 @@ def test_networkx(self):
G = load_graph2_networkx()
cost, sol = mcf.min_cost_flow_networkx(G)
self.assertEqual(cost, 150)
candidate = {
(0, 1): {"flow": 12.0},
(0, 2): {"flow": 8.0},
(1, 2): {"flow": 8.0},
(1, 3): {"flow": 4.0},
(2, 3): {"flow": 11.0},
(2, 4): {"flow": 5.0},
(3, 4): {"flow": 10.0},
}
candidate2 = {
(0, 1): {"flow": 12.0},
(0, 2): {"flow": 8.0},
(1, 3): {"flow": 4.0},
(1, 2): {"flow": 8.0},
(2, 3): {"flow": 15.0},
(2, 4): {"flow": 1.0},
(3, 4): {"flow": 14.0},
}
candidate = [
((0, 1), 12.0),
((0, 2), 8.0),
((1, 2), 8.0),
((1, 3), 4.0),
((2, 3), 11.0),
((2, 4), 5.0),
((3, 4), 10.0),
]
candidate2 = [
((0, 1), 12.0),
((0, 2), 8.0),
((1, 3), 4.0),
((1, 2), 8.0),
((2, 3), 15.0),
((2, 4), 1.0),
((3, 4), 14.0),
]
self.assertTrue(check_solution_networkx(sol, [candidate, candidate2]))

@unittest.skipIf(nx is None, "networkx is not installed")
Expand All @@ -223,23 +227,23 @@ def test_networkx_renamed(self):
G = nx.relabel_nodes(G, {0: "s", 4: "t"})
cost, sol = mcf.min_cost_flow_networkx(G)
self.assertEqual(cost, 150)
candidate = {
("s", 1): {"flow": 12.0},
("s", 2): {"flow": 8.0},
(1, 2): {"flow": 8.0},
(1, 3): {"flow": 4.0},
(2, 3): {"flow": 11.0},
(2, "t"): {"flow": 5.0},
(3, "t"): {"flow": 10.0},
}
candidate = [
(("s", 1), 12.0),
(("s", 2), 8.0),
((1, 2), 8.0),
((1, 3), 4.0),
((2, 3), 11.0),
((2, "t"), 5.0),
((3, "t"), 10.0),
]
candidate2 = {
("s", 1): {"flow": 12.0},
("s", 2): {"flow": 8.0},
(1, 3): {"flow": 4.0},
(1, 2): {"flow": 8.0},
(2, 3): {"flow": 15.0},
(2, "t"): {"flow": 1.0},
(3, "t"): {"flow": 14.0},
(("s", 1), 12.0),
(("s", 2), 8.0),
((1, 3), 4.0),
((1, 2), 8.0),
((2, 3), 15.0),
((2, "t"), 1.0),
((3, "t"), 14.0),
}
self.assertTrue(check_solution_networkx(sol, [candidate, candidate2]))

Expand All @@ -251,30 +255,32 @@ def test_pandas(self):
sol = sol[sol > 0]
self.assertEqual(cost, 49.0)

candidate = pd.DataFrame(
{
"source": [0, 0, 1, 1, 2, 2, 3, 2],
"target": [1, 2, 3, 2, 3, 4, 4, 3],
"flow": [12.0, 8.0, 4.0, 8.0, 10.0, 5.0, 10.0, 1.0],
}
).set_index(["source", "target"])

self.assertTrue(check_solution_pandas_multi(sol, [candidate]))
candidate = [
((0, 1), 12.0),
((0, 2), 8.0),
((1, 3), 4.0),
((1, 2), 8.0),
((2, 3), 10.0),
((2, 4), 5.0),
((3, 4), 10.0),
((2, 3), 1.0),
]
self.assertTrue(check_solution_pandas(sol, [candidate]))

@unittest.skipIf(nx is None, "networkx is not installed")
def test_networkx(self):
G = load_graph3_networkx(digraph=nx.MultiDiGraph)
cost, sol = mcf.min_cost_flow_networkx(G)
self.assertEqual(cost, 49.0)
candidate = [
(0, 1, {"flow": 12.0}),
(0, 2, {"flow": 8.0}),
(1, 3, {"flow": 4.0}),
(1, 2, {"flow": 8.0}),
(2, 3, {"flow": 10.0}),
(2, 3, {"flow": 1.0}),
(2, 4, {"flow": 5.0}),
(3, 4, {"flow": 10.0}),
((0, 1), 12.0),
((0, 2), 8.0),
((1, 3), 4.0),
((1, 2), 8.0),
((2, 3), 10.0),
((2, 3), 1.0),
((2, 4), 5.0),
((3, 4), 10.0),
]

self.assertTrue(check_solution_networkx_multi(sol, [candidate]))
self.assertTrue(check_solution_networkx(sol, [candidate]))

0 comments on commit 3e878d4

Please sign in to comment.