Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename CallSolver --> CreateAutoShardingSolverRequestAndCallSolver and CallORToolsSolver --> FormulateAndSolveMIPFromAutoShardingSolverRequest to better capture the function implementation. #17676

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1747,7 +1747,7 @@ std::unique_ptr<StrategyGroup> CreateReshapeStrategies(
return strategy_group;
}

AutoShardingSolverResult CallSolver(
AutoShardingSolverResult CreateAutoShardingSolverRequestAndCallSolver(
const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
const CostGraph& cost_graph, const AliasSet& alias_set,
Expand Down Expand Up @@ -1969,7 +1969,7 @@ AutoShardingSolverResult CallSolver(

PopulateTemporalValues(cost_graph, request);

return CallORToolsSolver(request);
return FormulateAndSolveMIPFromSolverRequest(request);
}

void CheckHloSharding(
Expand Down
14 changes: 7 additions & 7 deletions xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ AutoShardingSolverResult Solve(
const AutoShardingOption& option, absl::string_view request_prefix,
const absl::flat_hash_map<std::string, HloSharding>&
sharding_propagation_solution) {
return CallSolver(hlo_module, hlo_live_range, strategy_map, strategy_groups,
cost_graph, alias_set, node_intervals, edge_intervals,
node_groups, edge_groups, /*s_hint*/ {},
/*compute_iis*/ true, option.solver_timeout_in_seconds,
option, /*max_cost*/ std::nullopt, request_prefix,
sharding_propagation_solution,
/*deterministic mode*/ true);
return CreateAutoShardingSolverRequestAndCallSolver(
hlo_module, hlo_live_range, strategy_map, strategy_groups, cost_graph,
alias_set, node_intervals, edge_intervals, node_groups, edge_groups,
/*s_hint*/ {},
/*compute_iis*/ true, option.solver_timeout_in_seconds, option,
/*max_cost*/ std::nullopt, request_prefix, sharding_propagation_solution,
/*deterministic mode*/ true);
}

void PopulateTemporalValues(const CostGraph& cost_graph,
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ void AddMemoryTerms(
// can be a few (usually < 10) edges in the problem with negative costs. This
// is guaranteed to never produce a negative overall cost for the graph,
// however.
AutoShardingSolverResult CallORToolsSolver(
AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest(
const AutoShardingSolverRequest& unscaled_request) {
const absl::Time start_time = absl::Now();
const AutoShardingSolverRequest& request = ScaleRequest(unscaled_request);
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct AutoShardingSolverResult {
bool skip_auto_sharding;
};

AutoShardingSolverResult CallORToolsSolver(
AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest(
const AutoShardingSolverRequest& request);

enum AutoShardingViolationCode {
Expand Down
88 changes: 54 additions & 34 deletions xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,11 @@ AutoShardingSolverRequest AutoShardingSolverRequestWithEquivalences() {
return request;
}

TEST(CallORToolsSolverTest, SolvesOptimally) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOptimally) {
const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
const double objective_value = 7650.0;
Expand All @@ -262,12 +263,13 @@ TEST(CallORToolsSolverTest, SolvesOptimally) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, SolvesOverbudget) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOverbudget) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.set_memory_budget(100000);
request.mutable_overbudget_coeff()->set_coeff(10.0);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
const double objective_value = 9007650.0;
Expand All @@ -276,11 +278,12 @@ TEST(CallORToolsSolverTest, SolvesOverbudget) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, SolvesMaxDepartures) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesMaxDepartures) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.mutable_max_departures()->set_coeff(3.0);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
const double objective_value = 7872.0;
Expand All @@ -289,11 +292,12 @@ TEST(CallORToolsSolverTest, SolvesMaxDepartures) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, MinimizesDepartures) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, MinimizesDepartures) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.set_minimize_departures(true);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 1, 0, 0, 1};
const double objective_value = 3.0;
Expand All @@ -302,13 +306,14 @@ TEST(CallORToolsSolverTest, MinimizesDepartures) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, AvoidsInfiniteNodeCosts) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteNodeCosts) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.mutable_computation_costs(0)->set_costs(0, kInfinityCost);
request.mutable_computation_costs(0)->set_costs(1, kInfinityCost);
request.mutable_computation_costs(0)->set_costs(2, kInfinityCost);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {3, 0, 0, 0, 0};
const double objective_value = 10683.0;
Expand All @@ -317,11 +322,12 @@ TEST(CallORToolsSolverTest, AvoidsInfiniteNodeCosts) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, AvoidsInfiniteEdgeCosts) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteEdgeCosts) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.mutable_resharding_costs(0)->set_costs(0, kInfinityCost);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
const double objective_value = 7872.0;
Expand All @@ -330,7 +336,7 @@ TEST(CallORToolsSolverTest, AvoidsInfiniteEdgeCosts) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HandlesFollowedEdges) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesFollowedEdges) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
AutoShardingSolverRequest_Pair edge;
edge.set_first(1);
Expand All @@ -346,7 +352,8 @@ TEST(CallORToolsSolverTest, HandlesFollowedEdges) {
70000, 71000, 72000, 73000}};
AddCosts(request.mutable_duration_costs(), t);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
const double objective_value = 12650.0;
Expand All @@ -355,7 +362,7 @@ TEST(CallORToolsSolverTest, HandlesFollowedEdges) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HandlesCollapsedEdge) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesCollapsedEdge) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
AutoShardingSolverRequest_Pair edge;
edge.set_first(2);
Expand All @@ -373,7 +380,8 @@ TEST(CallORToolsSolverTest, HandlesCollapsedEdge) {
80000, 81000, 82000, 83000}};
AddCosts(request.mutable_duration_costs(), t);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
const double objective_value = 13972.0;
Expand All @@ -382,12 +390,13 @@ TEST(CallORToolsSolverTest, HandlesCollapsedEdge) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, UsesHint) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, UsesHint) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
const auto s_hint = {1, 0, 0, 0, 0}; // Not optimal, but close.
request.mutable_s_hint()->Add(s_hint.begin(), s_hint.end());

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
const double objective_value = 7650.0;
Expand All @@ -396,20 +405,22 @@ TEST(CallORToolsSolverTest, UsesHint) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HonorsMaxCost) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, HonorsMaxCost) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.mutable_max_cost()->set_coeff(7600.0); // Best possible is 7650.0

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

EXPECT_TRUE(absl::IsInternal(result.status.status()));
}

TEST(CallORToolsSolverTest, HandlesExtremelyHighMaxCost) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesExtremelyHighMaxCost) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.mutable_max_cost()->set_coeff(1e19);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
const double objective_value = 7650.0;
Expand All @@ -418,7 +429,7 @@ TEST(CallORToolsSolverTest, HandlesExtremelyHighMaxCost) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesMemoryEdgeCosts) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
const EdgeMatrix live_edges = {{}, {0}, {0, 1}, {1}, {}};
const CostMatrix memory_edge_costs = {{1000000, 1100, 1200, 1300,
Expand All @@ -432,7 +443,8 @@ TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) {
AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs);
request.set_enable_memory_edge_costs(true);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
const double objective_value = 7872.0;
Expand All @@ -441,7 +453,7 @@ TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HandlesIntervals) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesIntervals) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
const std::vector<std::pair<int64_t, int64_t>> node_intervals =
{{0, 4}, {0, 4}, {2, 3}, {3, 4}, {100, -1}};
Expand All @@ -460,7 +472,8 @@ TEST(CallORToolsSolverTest, HandlesIntervals) {
AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs);
request.set_enable_memory_edge_costs(true);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
const double objective_value = 7872.0;
Expand All @@ -469,7 +482,8 @@ TEST(CallORToolsSolverTest, HandlesIntervals) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) {
TEST(FormulateAndSolveMIPFromSolverRequestTest,
HandlesReducedIntervalsAndGroups) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
const std::vector<std::pair<int64_t, int64_t>> node_intervals =
{{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}};
Expand All @@ -492,7 +506,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) {
AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs);
request.set_enable_memory_edge_costs(true);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
const double objective_value = 7872.0;
Expand All @@ -501,7 +516,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) {
TEST(FormulateAndSolveMIPFromSolverRequestTest,
HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
const std::vector<std::pair<int64_t, int64_t>> node_intervals =
{{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}};
Expand All @@ -511,7 +527,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) {
AddGroups(request.mutable_node_groups(), node_groups);
request.set_enable_memory_edge_costs(false);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
const double objective_value = 7650.0;
Expand All @@ -520,7 +537,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) {
TEST(FormulateAndSolveMIPFromSolverRequestTest,
HandlesGroupsWithTinyMemoryCosts) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
const std::vector<std::pair<int64_t, int64_t>> node_intervals =
{{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}};
Expand Down Expand Up @@ -551,7 +569,8 @@ TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) {
request.set_enable_memory_edge_costs(true);
request.set_memory_budget(4321);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
const double objective_value = 7650.0;
Expand All @@ -560,11 +579,12 @@ TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, SolvesWithEquivalences) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesWithEquivalences) {
const AutoShardingSolverRequest request =
AutoShardingSolverRequestWithEquivalences();

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 5, 5, 1};
const double objective_value = 7650.0;
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ namespace spmd {

// A wrapper around the solver that converts the given objects into a
// combinatorial optimization problem & solves it.
AutoShardingSolverResult CallSolver(
AutoShardingSolverResult CreateAutoShardingSolverRequestAndCallSolver(
const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
const CostGraph& cost_graph, const AliasSet& alias_set,
Expand Down
Loading