Skip to content

Commit

Permalink
arith priv/dio slv/congr mng: Refactor to not use NodeManager::curren…
Browse files Browse the repository at this point in the history
…tNM() (cvc5#11222)
  • Loading branch information
daniel-larraz authored Sep 25, 2024
1 parent e7f25fc commit 6261f32
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 49 deletions.
13 changes: 6 additions & 7 deletions src/theory/arith/linear/congruence_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ namespace cvc5::internal {
namespace theory {
namespace arith::linear {

std::vector<Node> andComponents(TNode an)
std::vector<Node> andComponents(NodeManager* nm, TNode an)
{
auto nm = NodeManager::currentNM();
if (an == nm->mkConst(true))
{
return {};
Expand Down Expand Up @@ -278,7 +277,7 @@ void ArithCongruenceManager::watchedVariableCannotBeZero(ConstraintCP c){
TNode isZero = d_watchedEqualities[s];
TypeNode type = isZero[0].getType();
const auto isZeroPf = d_pnm->mkAssume(isZero);
const auto nm = NodeManager::currentNM();
const auto nm = nodeManager();
std::vector<std::shared_ptr<ProofNode>> pfs{isZeroPf, pf};
// Trick for getting correct, opposing signs.
std::vector<Node> coeff{nm->mkConstInt(Rational(-1 * cSign)),
Expand Down Expand Up @@ -370,7 +369,7 @@ bool ArithCongruenceManager::propagate(TNode x){
// we have a proof of (=> C L1) and need a proof of
// (not (and C L2)), where L1 and L2 are contradictory literals,
// stored in proven[1] and neg respectively below.
NodeManager* nm = NodeManager::currentNM();
NodeManager* nm = nodeManager();
std::vector<Node> conj(finalPf.begin(), finalPf.end());
CDProof cdp(d_env);
Node falsen = nm->mkConst(false);
Expand Down Expand Up @@ -494,7 +493,7 @@ TrustNode ArithCongruenceManager::explain(TNode external)
Trace("arith-ee") << "tweaking proof to prove " << external << " not "
<< trn.getProven()[1] << std::endl;
std::vector<std::shared_ptr<ProofNode>> assumptionPfs;
std::vector<Node> assumptions = andComponents(trn.getNode());
std::vector<Node> assumptions = andComponents(nodeManager(), trn.getNode());
assumptionPfs.push_back(trn.toProofNode());
for (const auto& a : assumptions)
{
Expand Down Expand Up @@ -615,7 +614,7 @@ void ArithCongruenceManager::equalsConstant(ConstraintCP c){

ArithVar x = c->getVariable();
Node xAsNode = d_avariables.asNode(x);
NodeManager* nm = NodeManager::currentNM();
NodeManager* nm = nodeManager();
Node asRational = nm->mkConstRealOrInt(
xAsNode.getType(), c->getValue().getNoninfinitesimalPart());

Expand Down Expand Up @@ -649,7 +648,7 @@ void ArithCongruenceManager::equalsConstant(ConstraintCP lb, ConstraintCP ub){
Node reason = mkAndFromBuilder(nb);

Node xAsNode = d_avariables.asNode(x);
NodeManager* nm = NodeManager::currentNM();
NodeManager* nm = nodeManager();
Node asRational = nm->mkConstRealOrInt(
xAsNode.getType(), lb->getValue().getNoninfinitesimalPart());

Expand Down
16 changes: 7 additions & 9 deletions src/theory/arith/linear/dio_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ namespace cvc5::internal {
namespace theory {
namespace arith::linear {

inline Node makeIntegerVariable(){
NodeManager* nm = NodeManager::currentNM();
inline Node makeIntegerVariable(NodeManager* nm)
{
SkolemManager* sm = nm->getSkolemManager();
return sm->mkDummySkolem("intvar",
nm->integerType(),
Expand Down Expand Up @@ -87,7 +87,7 @@ size_t DioSolver::allocateProofVariable() {
Assert(d_lastUsedProofVariable <= d_proofVariablePool.size());
if(d_lastUsedProofVariable == d_proofVariablePool.size()){
Assert(d_lastUsedProofVariable == d_proofVariablePool.size());
Node intVar = makeIntegerVariable();
Node intVar = makeIntegerVariable(nodeManager());
d_proofVariablePool.push_back(Variable(intVar));
}
size_t res = d_lastUsedProofVariable;
Expand All @@ -108,8 +108,7 @@ Node DioSolver::nextPureSubstitution(){
Polynomial p = sp.getPolynomial();
Constant c = -sp.getConstant();
Polynomial cancelV = p + Polynomial::mkPolynomial(v);
Node eq = NodeManager::currentNM()->mkNode(
Kind::EQUAL, v.getNode(), cancelV.getNode());
Node eq = nodeManager()->mkNode(Kind::EQUAL, v.getNode(), cancelV.getNode());
return eq;
}

Expand Down Expand Up @@ -146,7 +145,7 @@ void DioSolver::pushInputConstraint(const Comparison& eq, Node reason){

size_t varIndex = allocateProofVariable();
Variable proofVariable(d_proofVariablePool[varIndex]);
//Variable proofVariable(makeIntegerVariable());
// Variable proofVariable(makeIntegerVariable(nodeManager()));

TrailIndex posInTrail = d_trail.size();
Trace("dio::pushInputConstraint") << "pushInputConstraint @ " << posInTrail
Expand Down Expand Up @@ -678,7 +677,7 @@ std::pair<DioSolver::SubIndex, DioSolver::TrailIndex> DioSolver::decomposeIndex(
Assert(q.getPolynomial().getCoefficient(vl) == Constant::mkConstant(1));

Assert(!r.isZero());
Node freshNode = makeIntegerVariable();
Node freshNode = makeIntegerVariable(nodeManager());
Variable fresh(freshNode);
SumPair fresh_one=SumPair::mkSumPair(fresh);
SumPair fresh_a = fresh_one * a;
Expand Down Expand Up @@ -820,8 +819,7 @@ void DioSolver::addTrailElementAsLemma(TrailIndex i) {
Node DioSolver::trailIndexToEquality(TrailIndex i) const {
const SumPair& sp = d_trail[i].d_eq;
Node n = sp.getNode();
Node zero =
NodeManager::currentNM()->mkConstRealOrInt(n.getType(), Rational(0));
Node zero = nodeManager()->mkConstRealOrInt(n.getType(), Rational(0));
Node eq = n.eqNode(zero);
return eq;
}
Expand Down
64 changes: 32 additions & 32 deletions src/theory/arith/linear/theory_arith_private.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ namespace cvc5::internal {
namespace theory {
namespace arith::linear {

static Node toSumNode(const ArithVariables& vars, const DenseMap<Rational>& sum);
static Node toSumNode(NodeManager* nm,
const ArithVariables& vars,
const DenseMap<Rational>& sum);
static bool complexityBelow(const DenseMap<Rational>& row, uint32_t cap);

TheoryArithPrivate::TheoryArithPrivate(Env& env,
Expand Down Expand Up @@ -935,7 +937,7 @@ Node TheoryArithPrivate::getCandidateModelValue(TNode term)
const DeltaRational drv = getDeltaValue(term);
const Rational& delta = d_partialModel.getDelta();
const Rational qmodel = drv.substituteDelta( delta );
return NodeManager::currentNM()->mkConstRealOrInt(term.getType(), qmodel);
return nodeManager()->mkConstRealOrInt(term.getType(), qmodel);
} catch (DeltaRationalException& dr) {
return Node::null();
} catch (ModelException& me) {
Expand Down Expand Up @@ -985,7 +987,7 @@ Theory::PPAssertStatus TheoryArithPrivate::ppAssert(
Assert(elim == rewrite(elim));
if (elim.getType().isInteger() && !minVar.getType().isInteger())
{
elim = NodeManager::currentNM()->mkNode(Kind::TO_REAL, elim);
elim = nodeManager()->mkNode(Kind::TO_REAL, elim);
}
if (right.size() > options().arith.ppAssertMaxSubSize)
{
Expand Down Expand Up @@ -1393,8 +1395,7 @@ TrustNode TheoryArithPrivate::dioCutting()
Assert(!gcd.divides(c.asConstant().getNumerator()));
Comparison leq = Comparison::mkComparison(Kind::LEQ, p, c);
Comparison geq = Comparison::mkComparison(Kind::GEQ, p, c);
Node lemma = NodeManager::currentNM()->mkNode(
Kind::OR, leq.getNode(), geq.getNode());
Node lemma = nodeManager()->mkNode(Kind::OR, leq.getNode(), geq.getNode());
Node rewrittenLemma = rewrite(lemma);
Trace("arith::dio::ex") << "dioCutting found the plane: " << plane.getNode() << endl;
Trace("arith::dio::ex") << "resulting in the cut: " << lemma << endl;
Expand All @@ -1404,7 +1405,7 @@ TrustNode TheoryArithPrivate::dioCutting()
Trace("arith::dio") << "rewritten " << rewrittenLemma << endl;
if (proofsEnabled())
{
NodeManager* nm = NodeManager::currentNM();
NodeManager* nm = nodeManager();
Node gt = nm->mkNode(Kind::GT, p.getNode(), c.getNode());
Node lt = nm->mkNode(Kind::LT, p.getNode(), c.getNode());
TypeNode type = gt[0].getType();
Expand Down Expand Up @@ -1683,7 +1684,8 @@ bool TheoryArithPrivate::hasIntegerModel()
}
}

Node flattenAndSort(Node n){
Node flattenAndSort(NodeManager* nm, Node n)
{
Kind k = n.getKind();
switch(k){
case Kind::OR:
Expand All @@ -1709,11 +1711,9 @@ Node flattenAndSort(Node n){
}
Assert(out.size() >= 2);
std::sort(out.begin(), out.end());
return NodeManager::currentNM()->mkNode(k, out);
return nm->mkNode(k, out);
}



/** Outputs conflicts to the output channel. */
void TheoryArithPrivate::outputConflicts(){
Trace("arith::conflict") << "outputting conflicts" << std::endl;
Expand Down Expand Up @@ -1746,7 +1746,7 @@ void TheoryArithPrivate::outputConflicts(){
<< "d_conflicts[" << i << "] " << conflict
<< " has proof: " << hasProof << ", id = " << conf.second << endl;
if(TraceIsOn("arith::normalize::external")){
conflict = flattenAndSort(conflict);
conflict = flattenAndSort(nodeManager(), conflict);
Trace("arith::conflict") << "(normalized to) " << conflict << endl;
}

Expand All @@ -1765,7 +1765,7 @@ void TheoryArithPrivate::outputConflicts(){
Trace("arith::conflict") << "black box conflict" << bb
<< endl;
if(TraceIsOn("arith::normalize::external")){
bb = flattenAndSort(bb);
bb = flattenAndSort(nodeManager(), bb);
Trace("arith::conflict") << "(normalized to) " << bb << endl;
}
if (isProofEnabled() && d_blackBoxConflictPf.get())
Expand Down Expand Up @@ -1822,7 +1822,7 @@ void TheoryArithPrivate::outputPropagate(TNode lit) {

void TheoryArithPrivate::outputRestart() {
Trace("arith::channel") << "Arith restart!" << std::endl;
NodeManager* nm = NodeManager::currentNM();
NodeManager* nm = nodeManager();
SkolemManager* sm = nm->getSkolemManager();
Node restartVar = sm->mkDummySkolem(
"restartVar",
Expand Down Expand Up @@ -1960,7 +1960,7 @@ bool TheoryArithPrivate::replayLog(ApproximateSimplex* approx){
std::pair<ConstraintP, ArithVar> TheoryArithPrivate::replayGetConstraint(const DenseMap<Rational>& lhs, Kind k, const Rational& rhs, bool branch)
{
ArithVar added = ARITHVAR_SENTINEL;
Node sum = toSumNode(d_partialModel, lhs);
Node sum = toSumNode(nodeManager(), d_partialModel, lhs);
if(sum.isNull()){ return make_pair(NullConstraint, added); }

Trace("approx::constraint") << "replayGetConstraint " << sum
Expand All @@ -1970,7 +1970,7 @@ std::pair<ConstraintP, ArithVar> TheoryArithPrivate::replayGetConstraint(const D

Assert(k == Kind::LEQ || k == Kind::GEQ);

NodeManager* nm = NodeManager::currentNM();
NodeManager* nm = nodeManager();
Node comparison =
nm->mkNode(k, sum, nm->mkConstRealOrInt(sum.getType(), rhs));
Node rewritten = rewrite(comparison);
Expand Down Expand Up @@ -2071,9 +2071,11 @@ std::pair<ConstraintP, ArithVar> TheoryArithPrivate::replayGetConstraint(const C
return replayGetConstraint(lhs, k, rhs, ci.getKlass() == BranchCutKlass);
}

Node toSumNode(const ArithVariables& vars, const DenseMap<Rational>& sum){
Node toSumNode(NodeManager* nm,
const ArithVariables& vars,
const DenseMap<Rational>& sum)
{
Trace("arith::toSumNode") << "toSumNode() begin" << endl;
NodeManager* nm = NodeManager::currentNM();
DenseMap<Rational>::const_iterator iter, end;
iter = sum.begin(), end = sum.end();
std::vector<Node> children;
Expand Down Expand Up @@ -2595,7 +2597,7 @@ Node TheoryArithPrivate::branchToNode(ApproximateSimplex* approx,
return Node::null();
}
Rational fl(maybe_value.value().floor());
NodeManager* nm = NodeManager::currentNM();
NodeManager* nm = nodeManager();
Node leq =
nm->mkNode(Kind::LEQ, n, nm->mkConstRealOrInt(n.getType(), fl));
Node norm = rewrite(leq);
Expand All @@ -2609,9 +2611,9 @@ Node TheoryArithPrivate::cutToLiteral(ApproximateSimplex* approx, const CutInfo&
Assert(ci.reconstructed());

const DenseMap<Rational>& lhs = ci.getReconstruction().lhs;
Node sum = toSumNode(d_partialModel, lhs);
Node sum = toSumNode(nodeManager(), d_partialModel, lhs);
if(!sum.isNull()){
NodeManager* nm = NodeManager::currentNM();
NodeManager* nm = nodeManager();
Kind k = ci.getKind();
Assert(k == Kind::LEQ || k == Kind::GEQ);
Node rhs = nm->mkConstRealOrInt(sum.getType(), ci.getReconstruction().rhs);
Expand Down Expand Up @@ -3708,7 +3710,7 @@ void TheoryArithPrivate::propagate(Theory::Effort e) {
}
}

NodeManager* nm = NodeManager::currentNM();
NodeManager* nm = nodeManager();
while(d_congruenceManager.hasMorePropagations()){
TNode toProp = d_congruenceManager.getNextPropagation();

Expand Down Expand Up @@ -3941,7 +3943,7 @@ void TheoryArithPrivate::collectModelValues(
// TODO:
// This is not very good for user push/pop....
// Revisit when implementing push/pop
NodeManager* nm = NodeManager::currentNM();
NodeManager* nm = nodeManager();
for(var_iterator vi = var_begin(), vend = var_end(); vi != vend; ++vi){
ArithVar v = *vi;

Expand Down Expand Up @@ -4543,7 +4545,7 @@ bool TheoryArithPrivate::rowImplicationCanBeApplied(RowIndex ridx, bool rowUp, C
// Collect the farkas coefficients, as nodes.
std::vector<Node> farkasCoefficients;
farkasCoefficients.reserve(coeffs->size());
auto nm = NodeManager::currentNM();
auto nm = nodeManager();
std::transform(
coeffs->begin(),
coeffs->end(),
Expand Down Expand Up @@ -4873,10 +4875,8 @@ std::pair<bool, Node> TheoryArithPrivate::entailmentCheck(TNode lit)
return make_pair(false, Node::null());
}

bool TheoryArithPrivate::decomposeTerm(Node t,
Rational& m,
Node& p,
Rational& c)
bool TheoryArithPrivate::decomposeTerm(
NodeManager* nm, Node t, Rational& m, Node& p, Rational& c)
{
if(!Polynomial::isMember(t)){
return false;
Expand All @@ -4891,7 +4891,7 @@ bool TheoryArithPrivate::decomposeTerm(Node t,
Polynomial poly = Polynomial::parsePolynomial(t);
if(poly.isConstant()){
c = poly.getHead().getConstant().getValue();
p = NodeManager::currentNM()->mkConstReal(Rational(0));
p = nm->mkConstReal(Rational(0));
m = Rational(1);
return true;
}else if(poly.containsConstant()){
Expand Down Expand Up @@ -4956,14 +4956,14 @@ bool TheoryArithPrivate::decomposeLiteral(Node lit, Kind& k, int& dir, Rational&
// left : lm*( lp ) + lc
// right: rm*( rp ) + rc
Rational lc, rc;
bool success = decomposeTerm(rewrite(left), lm, lp, lc);
bool success = decomposeTerm(nodeManager(), rewrite(left), lm, lp, lc);
if(!success){ return false; }
success = decomposeTerm(rewrite(right), rm, rp, rc);
success = decomposeTerm(nodeManager(), rewrite(right), rm, rp, rc);
if(!success){ return false; }

Node diff = rewrite(NodeManager::currentNM()->mkNode(Kind::SUB, left, right));
Node diff = rewrite(nodeManager()->mkNode(Kind::SUB, left, right));
Rational dc;
success = decomposeTerm(diff, dm, dp, dc);
success = decomposeTerm(nodeManager(), diff, dm, dp, dc);
// can occur in entailment tests involving ITE terms
if (!success)
{
Expand Down
3 changes: 2 additions & 1 deletion src/theory/arith/linear/theory_arith_private.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ class TheoryArithPrivate : protected EnvObj
//std::pair<DeltaRational, Node> inferBound(TNode term, bool lb, int maxRounds = -1, const DeltaRational* threshold = NULL);

private:
static bool decomposeTerm(Node t, Rational& m, Node& p, Rational& c);
static bool decomposeTerm(
NodeManager* nm, Node t, Rational& m, Node& p, Rational& c);
bool decomposeLiteral(Node lit,
Kind& k,
int& dir,
Expand Down

0 comments on commit 6261f32

Please sign in to comment.