diff --git a/pytype/typegraph/solver.cc b/pytype/typegraph/solver.cc index c1096d725..e176b25be 100644 --- a/pytype/typegraph/solver.cc +++ b/pytype/typegraph/solver.cc @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -46,6 +47,84 @@ struct TraverseState { new_goals(std::move(new_goals)) {} }; +enum ActionType { + TRAVERSE, + INSERT_GOALS_TO_REMOVE, + ERASE_GOALS_TO_REMOVE, + ERASE_SEEN_GOALS, + ERASE_NEW_GOALS, + ERASE_REMOVED_GOALS, +}; + +// We're maintaining a state machine and actions to be able to do a DFS +// effectively. Rather than having to copy the states that are needed (which +// are four std::sets) whenever we need to traverse a new node, we keep track +// of the increment and the decrement between the previous node, and restore +// to the state of which were before after a node traversal, which is implement +// through "actions". +struct Action { + ActionType action_type; + // Goal either to delete are added to the corresponding set. + const Binding* goal; + // The iterator is for std::set and this is stable upon deletion and insertion + // if it's not directly the element being deleted or inserted. We will only + // try to erase the element on the exact node traversl, so we can safely + // reuse the iterator that was returned from the insertion. + // Not using this for action ERASE_GOALS_TO_REMOVE, as we are requesting + // for removal before the insertion has happened. + GoalSet::iterator erase_it; +}; + +static void traverse(const CFGNode* position, + std::vector& results, + std::stack& actions, TraverseState& state) { + if (state.goals_to_remove.empty()) { + results.emplace_back(GoalSet(state.removed_goals), + GoalSet(state.new_goals)); + return; + } + + const Binding* goal = *state.goals_to_remove.begin(); + state.goals_to_remove.erase(state.goals_to_remove.begin()); + actions.emplace(INSERT_GOALS_TO_REMOVE, goal); + + if (state.seen_goals.count(goal)) { + // Only process a goal once, to prevent infinite loops. + actions.emplace(TRAVERSE, nullptr); + return; + } + auto [it, added] = state.seen_goals.insert(goal); + actions.emplace(ERASE_SEEN_GOALS, nullptr, it); + + const auto* origin = goal->FindOrigin(position); + if (!origin) { + std::tie(it, added) = state.new_goals.insert(goal); + if (added) { + actions.emplace(ERASE_NEW_GOALS, nullptr, it); + } + actions.emplace(TRAVERSE, nullptr); + return; + } + + std::tie(it, added) = state.removed_goals.insert(goal); + if (added) { + actions.emplace(ERASE_REMOVED_GOALS, nullptr, it); + } + for (const auto& source_set : origin->source_sets) { + for (const Binding* next_goal : source_set) { + if (!state.goals_to_remove.count(next_goal)) { + actions.emplace(ERASE_GOALS_TO_REMOVE, next_goal); + } + } + actions.emplace(TRAVERSE, nullptr); + for (const Binding* next_goal : source_set) { + if (!state.goals_to_remove.count(next_goal)) { + actions.emplace(INSERT_GOALS_TO_REMOVE, next_goal); + } + } + } +} + // Remove all goals that can be fulfilled at the current CFG node. // Generates all possible sets of new goals obtained by replacing a goal that // originates at the current node with one of its source sets, iteratively, @@ -65,37 +144,31 @@ static std::vector remove_finished_goals(const CFGNode* pos, state.goals_to_remove.end(), std::inserter(state.new_goals, state.new_goals.begin()), pointer_less()); - std::deque queue; - queue.push_back(std::move(state)); + std::stack actions; + actions.emplace(TRAVERSE, nullptr); std::vector results; - while (!queue.empty()) { - state = std::move(queue.front()); - queue.pop_front(); - if (state.goals_to_remove.empty()) { - results.push_back(RemoveResult(state.removed_goals, state.new_goals)); - continue; - } - const auto* goal = *state.goals_to_remove.begin(); - state.goals_to_remove.erase(state.goals_to_remove.begin()); - if (state.seen_goals.count(goal)) { - // Only process a goal once, to prevent infinite loops. - queue.emplace_back(std::move(state)); - continue; - } - state.seen_goals.insert(goal); - const auto* origin = goal->FindOrigin(pos); - if (!origin) { - state.new_goals.insert(goal); - queue.emplace_back(std::move(state)); - continue; - } - state.removed_goals.insert(goal); - for (const auto& source_set : origin->source_sets) { - GoalSet next_goals_to_remove(state.goals_to_remove); - next_goals_to_remove.insert(source_set.begin(), source_set.end()); - queue.push_back(TraverseState(std::move(next_goals_to_remove), - state.seen_goals, state.removed_goals, - state.new_goals)); + while (!actions.empty()) { + Action action = actions.top(); + actions.pop(); + switch (action.action_type) { + case TRAVERSE: + traverse(pos, results, actions, state); + break; + case INSERT_GOALS_TO_REMOVE: + state.goals_to_remove.insert(action.goal); + break; + case ERASE_GOALS_TO_REMOVE: + state.goals_to_remove.erase(action.goal); + break; + case ERASE_SEEN_GOALS: + state.seen_goals.erase(action.erase_it); + break; + case ERASE_NEW_GOALS: + state.new_goals.erase(action.erase_it); + break; + case ERASE_REMOVED_GOALS: + state.removed_goals.erase(action.erase_it); + break; } } return results;