From 8f4d7925dd48513880ab5b53c235d2b1ca0abef1 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Mon, 22 Apr 2024 10:49:04 -0500 Subject: [PATCH 1/6] Fix RARE fixed point matching under contexts (#10643) The implementation of proof reconstruction for rules with fixed-point with contexts had several bugs and would never be used in final proofs. This PR corrects the issues. --- src/expr/node_algorithm.cpp | 2 +- src/rewriter/rewrite_db_proof_cons.cpp | 52 ++++++++++++++++++++------ 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/src/expr/node_algorithm.cpp b/src/expr/node_algorithm.cpp index b114c0c98fd..affea6d1cb8 100644 --- a/src/expr/node_algorithm.cpp +++ b/src/expr/node_algorithm.cpp @@ -704,7 +704,7 @@ bool match(Node x, Node y, std::unordered_map& subs) visited.insert(curr); if (curr.first.getNumChildren() == 0) { - if (curr.first.getType() != curr.second.getType()) + if (!curr.first.getType().isComparableTo(curr.second.getType())) { // the two subterms have different types return false; diff --git a/src/rewriter/rewrite_db_proof_cons.cpp b/src/rewriter/rewrite_db_proof_cons.cpp index 32b6ac02264..21ecc6aa7a8 100644 --- a/src/rewriter/rewrite_db_proof_cons.cpp +++ b/src/rewriter/rewrite_db_proof_cons.cpp @@ -258,7 +258,7 @@ bool RewriteDbProofCons::notifyMatch(const Node& s, << std::endl; const RewriteProofRule& rpr = d_db->getRule(d_currFixedPointId); // get the conclusion - Node target = rpr.getConclusion(); + Node target = rpr.getConclusion(true); // apply substitution, which may notice vars may be out of order wrt rule // var list target = expr::narySubstitute(target, vars, subs); @@ -323,7 +323,10 @@ bool RewriteDbProofCons::proveWithRule(RewriteProofStatus id, ProofRewriteRule r) { Assert(!target.isNull() && target.getKind() == Kind::EQUAL); - Trace("rpc-debug2") << "Check rule " << id << std::endl; + Trace("rpc-debug2") << "Check rule " + << (id == RewriteProofStatus::DSL ? toString(r) + : toString(id)) + << std::endl; std::vector vcs; Node transEq; ProvenInfo pic; @@ -950,13 +953,19 @@ Node RewriteDbProofCons::getRuleConclusion(const RewriteProofRule& rpr, { pi.d_id = RewriteProofStatus::DSL; pi.d_dslId = rpr.getId(); - Node conc = rpr.getConclusion(); + Node conc = rpr.getConclusion(true); + Node concRhs = conc[1]; + Trace("rpc-ctx") << "***GET CONCLUSION " << pi.d_dslId << " for " << vars + << " -> " << subs << std::endl; // if fixed point, we continue applying if (doFixedPoint && rpr.isFixedPoint()) { Assert(d_currFixedPointId == ProofRewriteRule::NONE); Assert(d_currFixedPointConc.isNull()); d_currFixedPointId = rpr.getId(); + Node context = rpr.getContext(); + Assert(!context.isNull()); + Trace("rpc-ctx") << "Context is " << context << std::endl; // check if stgt also rewrites with the same rule? bool continueFixedPoint; std::vector steps; @@ -968,24 +977,42 @@ Node RewriteDbProofCons::getRuleConclusion(const RewriteProofRule& rpr, Node stgt = ssrc; do { + Trace("rpc-ctx") << "Get matches " << stgt << std::endl; + Trace("rpc-ctx") << "Conclusion is " << concRhs << std::endl; continueFixedPoint = false; rpr.getMatches(stgt, &d_notify); + Trace("rpc-ctx") << "...conclusion is " << d_currFixedPointConc + << std::endl; if (!d_currFixedPointConc.isNull()) { // currently avoid accidental loops: arbitrarily bound to 1000 continueFixedPoint = steps.size() <= s_fixedPointLimit; Assert(d_currFixedPointConc.getKind() == Kind::EQUAL); - steps.push_back(d_currFixedPointConc[1]); stepsSubs.emplace_back(d_currFixedPointSubs.begin(), d_currFixedPointSubs.end()); stgt = d_currFixedPointConc[1]; + // For example, we have now computed + // (str.len (str.++ x y z)) --> (+ (str.len x) (str.len (str.++ y z))) + // where stgt is the RHS. We now want to continue to rewrite the right + // hand side, where to find the next target term to rewrite, we match + // the context of the rule to the RHS. + // In particular, given a context (+ (str.len S0) ?0), we would match + // ?0 to (str.len (str.++ y z)). This indicates that the user suggests + // that (str.len (str.++ y z)) is the term to continue to rewrite. + // We update stgt to this term to proceed with the loop. + std::unordered_map msubs; + expr::match(context[1], stgt, msubs); + Trace("rpc-ctx") << "Matching context " << context << " with " << stgt + << " gives " << msubs[context[0][0]] << std::endl; + stgt = msubs[context[0][0]]; + Assert(!stgt.isNull()); + steps.push_back(stgt); } d_currFixedPointConc = Node::null(); } while (continueFixedPoint); std::vector transEq; Node prev = ssrc; - Node context = rpr.getContext(); Node placeholder = context[0][0]; Node body = context[1]; Node currConc = body; @@ -998,6 +1025,9 @@ Node RewriteDbProofCons::getRuleConclusion(const RewriteProofRule& rpr, Node target = expr::narySubstitute(body, vars, stepSubs); target = target.substitute(TNode(placeholder), TNode(step)); cacheProofSubPlaceholder(currContext, placeholder, source, target); + Trace("rpc-ctx") << "Step " << source << " == " << target << " from " + << body << " " << vars << " -> " << stepSubs << ", " + << placeholder << " -> " << step << std::endl; ProvenInfo& dpi = d_pcache[source.eqNode(target)]; dpi.d_id = pi.d_id; @@ -1023,19 +1053,16 @@ Node RewriteDbProofCons::getRuleConclusion(const RewriteProofRule& rpr, pi.d_id = RewriteProofStatus::TRANS; // store transEq in d_vars pi.d_vars = transEq; + Trace("rpc-ctx") << "***RETURN trans " << transEq.back()[1] << std::endl; // return the end of the chain, which will be used for constrained // matching return transEq.back()[1]; } } - Node res = conc[1]; - if (rpr.isFixedPoint()) - { - Node context = rpr.getContext(); - res = context[1].substitute(TNode(context[0][0]), TNode(conc[1])); - } - return expr::narySubstitute(res, vars, subs); + Node ret = expr::narySubstitute(concRhs, vars, subs); + Trace("rpc-ctx") << "***RETURN " << ret << std::endl; + return ret; } void RewriteDbProofCons::cacheProofSubPlaceholder(TNode context, @@ -1079,6 +1106,7 @@ void RewriteDbProofCons::cacheProofSubPlaceholder(TNode context, { ProvenInfo& cpi = d_pcache[cong]; cpi.d_id = RewriteProofStatus::CONG; + cpi.d_vars.clear(); for (size_t i = 0, size = cong[0].getNumChildren(); i < size; i++) { TNode lhs = cong[0][i]; From 0202c0d9ddb44e1b05bc914230683dc2ff16d23a Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Mon, 22 Apr 2024 11:14:51 -0500 Subject: [PATCH 2/6] Add proofs for static rewrite preprocessing pass (#10657) Required for narrowing down inputs to the DSL reconstruction. Reduces the number of preprocessing holes from 343 -> 330 on my dev branch. --- src/preprocessing/passes/static_rewrite.cpp | 28 +++++++++++++++++---- src/preprocessing/passes/static_rewrite.h | 5 ++++ src/proof/trust_id.cpp | 1 + src/proof/trust_id.h | 2 ++ 4 files changed, 31 insertions(+), 5 deletions(-) diff --git a/src/preprocessing/passes/static_rewrite.cpp b/src/preprocessing/passes/static_rewrite.cpp index 8d696f71b18..23ac2674e2c 100644 --- a/src/preprocessing/passes/static_rewrite.cpp +++ b/src/preprocessing/passes/static_rewrite.cpp @@ -15,6 +15,7 @@ #include "preprocessing/passes/static_rewrite.h" +#include "options/smt_options.h" #include "preprocessing/assertion_pipeline.h" #include "preprocessing/preprocessing_pass_context.h" #include "theory/theory_engine.h" @@ -25,9 +26,18 @@ namespace cvc5::internal { namespace preprocessing { namespace passes { -StaticRewrite::StaticRewrite( - PreprocessingPassContext* preprocContext) - : PreprocessingPass(preprocContext, "static-rewrite"){}; +StaticRewrite::StaticRewrite(PreprocessingPassContext* preprocContext) + : PreprocessingPass(preprocContext, "static-rewrite") +{ + if (options().smt.produceProofs) + { + d_tpg.reset(new TConvProofGenerator(d_env, + userContext(), + TConvPolicy::FIXPOINT, + TConvCachePolicy::NEVER, + "StaticRewrite::tpg")); + } +} PreprocessingPassResult StaticRewrite::applyInternal( AssertionPipeline* assertions) @@ -125,6 +135,14 @@ TrustNode StaticRewrite::rewriteAssertion(TNode n) rewrittenTo[cur] = retr; rewrittenTo[ret] = retr; visit.push_back(retr); + if (d_tpg != nullptr) + { + d_tpg->addRewriteStep(ret, + trn.getNode(), + trn.getGenerator(), + false, + TrustId::PP_STATIC_REWRITE); + } } if (!wasRewritten) { @@ -144,8 +162,8 @@ TrustNode StaticRewrite::rewriteAssertion(TNode n) { return TrustNode::null(); } - // can make proof producing by providing a term conversion generator here - return TrustNode::mkTrustRewrite(n, ret, nullptr); + // use the term conversion proof generator if it exists + return TrustNode::mkTrustRewrite(n, ret, d_tpg.get()); } } // namespace passes diff --git a/src/preprocessing/passes/static_rewrite.h b/src/preprocessing/passes/static_rewrite.h index 9f1df9398a7..1b408e17078 100644 --- a/src/preprocessing/passes/static_rewrite.h +++ b/src/preprocessing/passes/static_rewrite.h @@ -20,6 +20,7 @@ #include "expr/node.h" #include "preprocessing/preprocessing_pass.h" +#include "proof/conv_proof_generator.h" #include "proof/trust_node.h" namespace cvc5::internal { @@ -48,6 +49,10 @@ class StaticRewrite : public PreprocessingPass * Returns the trust node corresponding to the rewrite. */ TrustNode rewriteAssertion(TNode assertion); + + private: + /** A term conversion proof generator */ + std::unique_ptr d_tpg; }; } // namespace passes diff --git a/src/proof/trust_id.cpp b/src/proof/trust_id.cpp index 42a2bdb6692..1fb24d9f13b 100644 --- a/src/proof/trust_id.cpp +++ b/src/proof/trust_id.cpp @@ -31,6 +31,7 @@ const char* toString(TrustId id) case TrustId::THEORY_INFERENCE: return "THEORY_INFERENCE"; case TrustId::PREPROCESS: return "PREPROCESS"; case TrustId::PREPROCESS_LEMMA: return "PREPROCESS_LEMMA"; + case TrustId::PP_STATIC_REWRITE: return "PP_STATIC_REWRITE"; case TrustId::THEORY_PREPROCESS: return "THEORY_PREPROCESS"; case TrustId::THEORY_PREPROCESS_LEMMA: return "THEORY_PREPROCESS_LEMMA"; case TrustId::THEORY_EXPAND_DEF: return "THEORY_EXPAND_DEF"; diff --git a/src/proof/trust_id.h b/src/proof/trust_id.h index 8fc66413bef..c4068946833 100644 --- a/src/proof/trust_id.h +++ b/src/proof/trust_id.h @@ -36,6 +36,8 @@ enum class TrustId : uint32_t PREPROCESS, /** A lemma added during preprocessing without a proof */ PREPROCESS_LEMMA, + /** A ppStaticRewrite step */ + PP_STATIC_REWRITE, /** A rewrite of the input formula made by a theory during preprocessing without a proof */ THEORY_PREPROCESS, From 776a08459d051210af0a8f6fc7b2ae9edbd5c142 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Mon, 22 Apr 2024 12:22:37 -0500 Subject: [PATCH 3/6] Add conversion method for RARE in ALF involving `:list` variables (#10617) Towards the integration of RARE+ALF. It also eliminates a getNullTerminator utility that is now obsolete after nil terminators are perfectly reflected in ALF. --- src/CMakeLists.txt | 2 + src/proof/alf/alf_list_node_converter.cpp | 77 +++++++++++++++++++++ src/proof/alf/alf_list_node_converter.h | 82 +++++++++++++++++++++++ src/proof/alf/alf_node_converter.cpp | 39 ----------- src/proof/alf/alf_node_converter.h | 10 --- 5 files changed, 161 insertions(+), 49 deletions(-) create mode 100644 src/proof/alf/alf_list_node_converter.cpp create mode 100644 src/proof/alf/alf_list_node_converter.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 87bb5726b3d..6fe763e7306 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -148,6 +148,8 @@ libcvc5_add_sources( printer/smt2/smt2_printer.h proof/alf/alf_dependent_type_converter.cpp proof/alf/alf_dependent_type_converter.h + proof/alf/alf_list_node_converter.cpp + proof/alf/alf_list_node_converter.h proof/alf/alf_node_converter.cpp proof/alf/alf_node_converter.h proof/alf/alf_print_channel.cpp diff --git a/src/proof/alf/alf_list_node_converter.cpp b/src/proof/alf/alf_list_node_converter.cpp new file mode 100644 index 00000000000..717f3f5cc51 --- /dev/null +++ b/src/proof/alf/alf_list_node_converter.cpp @@ -0,0 +1,77 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2023 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Implementation of ALF node conversion for list variables in DSL rules + */ + +#include "proof/alf/alf_list_node_converter.h" + +#include "expr/nary_term_util.h" +#include "printer/printer.h" +#include "printer/smt2/smt2_printer.h" + +namespace cvc5::internal { +namespace proof { + +AlfListNodeConverter::AlfListNodeConverter(NodeManager* nm, + BaseAlfNodeConverter& tproc) + : NodeConverter(nm), d_tproc(tproc) +{ +} + +Node AlfListNodeConverter::postConvert(Node n) +{ + Kind k = n.getKind(); + switch (k) + { + case Kind::STRING_CONCAT: + case Kind::BITVECTOR_ADD: + case Kind::BITVECTOR_MULT: + case Kind::BITVECTOR_AND: + case Kind::BITVECTOR_OR: + case Kind::BITVECTOR_XOR: + case Kind::FINITE_FIELD_ADD: + case Kind::FINITE_FIELD_MULT: + case Kind::OR: + case Kind::AND: + case Kind::SEP_STAR: + case Kind::ADD: + case Kind::MULT: + case Kind::NONLINEAR_MULT: + case Kind::BITVECTOR_CONCAT: + case Kind::REGEXP_CONCAT: + case Kind::REGEXP_UNION: + case Kind::REGEXP_INTER: + // operators with a ground null terminator + break; + default: + // not an n-ary kind + return n; + } + size_t nlistChildren = 0; + for (const Node& nc : n) + { + if (!expr::isListVar(nc)) + { + nlistChildren++; + } + } + // if less than 2 non-list children, it might collapse to a single element + if (nlistChildren < 2) + { + return d_tproc.mkInternalApp("$dsl.singleton_elim", {n}, n.getType()); + } + return n; +} + +} // namespace proof +} // namespace cvc5::internal diff --git a/src/proof/alf/alf_list_node_converter.h b/src/proof/alf/alf_list_node_converter.h new file mode 100644 index 00000000000..d152b323970 --- /dev/null +++ b/src/proof/alf/alf_list_node_converter.h @@ -0,0 +1,82 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2023 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Implementation of ALF node conversion for list variables in DSL rules + */ +#include "cvc5_private.h" + +#ifndef CVC4__PROOF__ALF__ALF_LIST_NODE_CONVERTER_H +#define CVC4__PROOF__ALF__ALF_LIST_NODE_CONVERTER_H + +#include "expr/node_converter.h" +#include "proof/alf/alf_node_converter.h" + +namespace cvc5::internal { +namespace proof { + +/** + * This node converter adds applications of "singleton elimination" to + * accurately reflect the difference in semantics between ALF and RARE. + * + * This is used when printing RARE rules in ALF. For example, the RARE rule: + * + * (define-rule bool-or-false ((xs Bool :list) (ys Bool :list)) + * (or xs false ys) + * (or xs ys)) + * + * becomes the ALF rule: + * + * (declare-rule dsl.bool-or-false ((xs Bool :list) (ys Bool :list)) + * :args (xs ys) + * :conclusion (= (or xs false ys)) ($dsl.singleton_elim (or xs ys))) + * ) + * + * Where note that $dsl.singleton_elim is defined in our ALF signature: + * + * (program $dsl.singleton_elim + * ((T Type) (S Type) (U Type) (f (-> T U S)) (x S) (x1 T) (x2 T :list)) + * (S) S + * ( + * (($dsl.singleton_elim (f x1 x2)) + * (alf.ite (alf.is_eq x2 (alf.nil f x1 x2)) x1 (f x1 x2))) + * (($dsl.singleton_elim x) + * x) + * ) + * ) + * + * In the above rule, notice that $dsl.singleton_elim is applied to (or xs ys). + * The reason is that (or xs ys) *may* become a singleton list when xs and ys + * are instantiated. Say xs -> [A] and ys -> []. In RARE, the conclusion is + * (= (or A false) A) + * In ALF, the conclusion is: + * (= (or A false) (or A)) + * + * The above transformation takes into account the difference in semantics. + * More generally, we apply $dsl.singleton_elim to any subterm of the input + * term that has fewer than 2 children that are not marked with :list. + */ +class AlfListNodeConverter : public NodeConverter +{ + public: + AlfListNodeConverter(NodeManager* nm, BaseAlfNodeConverter& tproc); + /** Convert node n based on the conversion described above. */ + Node postConvert(Node n) override; + + private: + /** The parent converter, used for getting internal symbols and utilities */ + BaseAlfNodeConverter& d_tproc; +}; + +} // namespace proof +} // namespace cvc5::internal + +#endif diff --git a/src/proof/alf/alf_node_converter.cpp b/src/proof/alf/alf_node_converter.cpp index bae4588d3f6..2c6963bafca 100644 --- a/src/proof/alf/alf_node_converter.cpp +++ b/src/proof/alf/alf_node_converter.cpp @@ -379,45 +379,6 @@ Node AlfNodeConverter::mkNil(TypeNode tn) return mkInternalSymbol("alf.nil", tn); } -Node AlfNodeConverter::getNullTerminator(Kind k, TypeNode tn) -{ - // note this method should remain in sync with getCongRule in - // proof_node_algorithm.cpp. - switch (k) - { - case Kind::APPLY_UF: - case Kind::DISTINCT: - case Kind::FLOATINGPOINT_LT: - case Kind::FLOATINGPOINT_LEQ: - case Kind::FLOATINGPOINT_GT: - case Kind::FLOATINGPOINT_GEQ: - // the above operators may take arbitrary number of arguments but are not - // marked as n-ary in ALF - return Node::null(); - case Kind::APPLY_CONSTRUCTOR: - // tuple constructor is n-ary with unit tuple as null terminator - if (tn.isTuple()) - { - TypeNode tnu = NodeManager::currentNM()->mkTupleType({}); - return NodeManager::currentNM()->mkGroundValue(tnu); - } - return Node::null(); - break; - case Kind::OR: return NodeManager::currentNM()->mkConst(false); - case Kind::SEP_STAR: - case Kind::AND: return NodeManager::currentNM()->mkConst(true); - case Kind::ADD: return NodeManager::currentNM()->mkConstInt(Rational(0)); - case Kind::MULT: - case Kind::NONLINEAR_MULT: - return NodeManager::currentNM()->mkConstInt(Rational(1)); - case Kind::BITVECTOR_CONCAT: - return mkInternalSymbol("@bvempty", - NodeManager::currentNM()->mkBitVectorType(0)); - default: break; - } - return mkNil(tn); -} - Node AlfNodeConverter::mkList(const std::vector& args) { TypeNode tn = NodeManager::currentNM()->booleanType(); diff --git a/src/proof/alf/alf_node_converter.h b/src/proof/alf/alf_node_converter.h index 89996a43d24..ac729d4c8f4 100644 --- a/src/proof/alf/alf_node_converter.h +++ b/src/proof/alf/alf_node_converter.h @@ -51,7 +51,6 @@ class BaseAlfNodeConverter : public NodeConverter * passed as arguments to terms and proof rules. */ virtual Node typeAsNode(TypeNode tni) = 0; - /** * Make an internal symbol with custom name. This is a BOUND_VARIABLE that * has a distinguished status so that it is *not* printed as (bvar ...). The @@ -94,15 +93,6 @@ class AlfNodeConverter : public BaseAlfNodeConverter * @return the operator. */ Node getOperatorOfTerm(Node n, bool reqCast = false) override; - /** - * Get the null terminator for kind k and type tn. The type tn can be - * omitted if applications of kind k do not have parametric type. - * - * The returned null terminator is *not* converted to internal form. - * - * For examples of null terminators, see nary_term_utils.h. - */ - Node getNullTerminator(Kind k, TypeNode tn); /** Make generic list */ Node mkList(const std::vector& args); /** From 4f78aaea95f684db80cc2d7c3e3f31a61dadf8bc Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Mon, 22 Apr 2024 13:05:06 -0500 Subject: [PATCH 4/6] Add rel.join operator for relations. (#10433) --- include/cvc5/cvc5_kind.h | 32 +++ src/api/cpp/cvc5.cpp | 8 + src/parser/smt2/smt2_state.cpp | 5 +- src/printer/smt2/smt2_printer.cpp | 2 + src/theory/bags/bags_utils.cpp | 2 - src/theory/builtin/generic_op.cpp | 21 +- src/theory/inference_id.cpp | 3 + src/theory/inference_id.h | 2 + src/theory/quantifiers/term_util.cpp | 4 +- src/theory/sets/kinds | 13 ++ src/theory/sets/theory_sets.cpp | 1 + src/theory/sets/theory_sets_private.cpp | 8 +- src/theory/sets/theory_sets_rels.cpp | 192 +++++++++++++++++- src/theory/sets/theory_sets_rels.h | 35 ++++ src/theory/sets/theory_sets_rewriter.cpp | 47 +++++ src/theory/sets/theory_sets_rewriter.h | 1 + src/theory/sets/theory_sets_type_rules.cpp | 96 ++++++++- src/theory/sets/theory_sets_type_rules.h | 18 ++ src/theory/theory_engine.cpp | 5 + test/regress/cli/CMakeLists.txt | 1 + .../cli/regress1/rels/relation_join1.smt2 | 27 +++ 21 files changed, 501 insertions(+), 22 deletions(-) create mode 100644 test/regress/cli/regress1/rels/relation_join1.smt2 diff --git a/include/cvc5/cvc5_kind.h b/include/cvc5/cvc5_kind.h index 3e86dbfef30..3c9e2b3d64a 100644 --- a/include/cvc5/cvc5_kind.h +++ b/include/cvc5/cvc5_kind.h @@ -3555,6 +3555,38 @@ enum ENUM(Kind) : int32_t * - Solver::mkOp(Kind, const std::vector&) const */ EVALUE(RELATION_JOIN), + /** + * \rst + * Table join operator for relations has the form + * :math:`((\_ \; rel.table\_join \; m_1 \; n_1 \; \dots \; m_k \; n_k) \; A \; B)` + * where :math:`m_1 \; n_1 \; \dots \; m_k \; n_k` are natural numbers, + * and :math:`A, B` are relations. + * This operator filters the product of two sets based on the equality of + * projected tuples using indices :math:`m_1, \dots, m_k` in relation :math:`A`, + * and indices :math:`n_1, \dots, n_k` in relation :math:`B`. + * + * - Arity: ``2`` + * + * - ``1:`` Term of relation Sort + * + * - ``2:`` Term of relation Sort + * + * - Indices: ``n`` + * - ``1..n:`` Indices of the projection + * + * \endrst + * - Create Term of this Kind with: + * - Solver::mkTerm(const Op&, const std::vector&) const + * + * - Create Op of this kind with: + * - Solver::mkOp(Kind, const std::vector&) const + * + * \rst + * .. warning:: This kind is experimental and may be changed or removed in + * future versions. + * \endrst + */ + EVALUE(RELATION_TABLE_JOIN), /** * Relation cartesian product. * diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index 0e7c9c994bd..48fa89ca39e 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -340,6 +340,8 @@ const static std::unordered_map> KIND_ENUM(Kind::SET_FOLD, internal::Kind::SET_FOLD), /* Relations -------------------------------------------------------- */ KIND_ENUM(Kind::RELATION_JOIN, internal::Kind::RELATION_JOIN), + KIND_ENUM(Kind::RELATION_TABLE_JOIN, + internal::Kind::RELATION_TABLE_JOIN), KIND_ENUM(Kind::RELATION_PRODUCT, internal::Kind::RELATION_PRODUCT), KIND_ENUM(Kind::RELATION_TRANSPOSE, internal::Kind::RELATION_TRANSPOSE), KIND_ENUM(Kind::RELATION_TCLOSURE, internal::Kind::RELATION_TCLOSURE), @@ -735,6 +737,8 @@ const static std::unordered_map s_op_kinds{ {Kind::RELATION_AGGREGATE, internal::Kind::RELATION_AGGREGATE_OP}, {Kind::RELATION_GROUP, internal::Kind::RELATION_GROUP_OP}, {Kind::RELATION_PROJECT, internal::Kind::RELATION_PROJECT_OP}, + {Kind::RELATION_TABLE_JOIN, internal::Kind::RELATION_TABLE_JOIN_OP}, {Kind::TABLE_PROJECT, internal::Kind::TABLE_PROJECT_OP}, {Kind::TABLE_AGGREGATE, internal::Kind::TABLE_AGGREGATE_OP}, {Kind::TABLE_JOIN, internal::Kind::TABLE_JOIN_OP}, @@ -2131,6 +2136,7 @@ size_t Op::getNumIndicesHelper() const case Kind::RELATION_AGGREGATE: case Kind::RELATION_GROUP: case Kind::RELATION_PROJECT: + case Kind::RELATION_TABLE_JOIN: case Kind::TABLE_AGGREGATE: case Kind::TABLE_GROUP: case Kind::TABLE_JOIN: @@ -2302,6 +2308,7 @@ Term Op::getIndexHelper(size_t index) case Kind::RELATION_AGGREGATE: case Kind::RELATION_GROUP: case Kind::RELATION_PROJECT: + case Kind::RELATION_TABLE_JOIN: case Kind::TABLE_AGGREGATE: case Kind::TABLE_GROUP: case Kind::TABLE_JOIN: @@ -5898,6 +5905,7 @@ Op TermManager::mkOp(Kind kind, const std::vector& args) case Kind::RELATION_AGGREGATE: case Kind::RELATION_GROUP: case Kind::RELATION_PROJECT: + case Kind::RELATION_TABLE_JOIN: case Kind::TABLE_AGGREGATE: case Kind::TABLE_GROUP: case Kind::TABLE_JOIN: diff --git a/src/parser/smt2/smt2_state.cpp b/src/parser/smt2/smt2_state.cpp index 1777b0085b1..b7a9691b369 100644 --- a/src/parser/smt2/smt2_state.cpp +++ b/src/parser/smt2/smt2_state.cpp @@ -867,6 +867,7 @@ void Smt2State::setLogic(std::string name) addOperator(Kind::SET_FILTER, "set.filter"); addOperator(Kind::SET_FOLD, "set.fold"); addOperator(Kind::RELATION_JOIN, "rel.join"); + addOperator(Kind::RELATION_TABLE_JOIN, "rel.table_join"); addOperator(Kind::RELATION_PRODUCT, "rel.product"); addOperator(Kind::RELATION_TRANSPOSE, "rel.transpose"); addOperator(Kind::RELATION_TCLOSURE, "rel.tclosure"); @@ -877,6 +878,7 @@ void Smt2State::setLogic(std::string name) addOperator(Kind::RELATION_AGGREGATE, "rel.aggr"); addOperator(Kind::RELATION_PROJECT, "rel.project"); addIndexedOperator(Kind::RELATION_GROUP, "rel.group"); + addIndexedOperator(Kind::RELATION_TABLE_JOIN, "rel.table_join"); addIndexedOperator(Kind::RELATION_AGGREGATE, "rel.aggr"); addIndexedOperator(Kind::RELATION_PROJECT, "rel.project"); // set.comprehension is a closure kind @@ -1332,7 +1334,8 @@ Term Smt2State::applyParseOp(const ParseOp& p, std::vector& args) if (kind == Kind::TUPLE_PROJECT || kind == Kind::TABLE_PROJECT || kind == Kind::TABLE_AGGREGATE || kind == Kind::TABLE_JOIN || kind == Kind::TABLE_GROUP || kind == Kind::RELATION_GROUP - || kind == Kind::RELATION_AGGREGATE || kind == Kind::RELATION_PROJECT) + || kind == Kind::RELATION_AGGREGATE || kind == Kind::RELATION_PROJECT + || kind == Kind::RELATION_TABLE_JOIN) { std::vector indices; Op op = d_tm.mkOp(kind, indices); diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 4089cf212c5..84f3ddf3b09 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -514,6 +514,7 @@ bool Smt2Printer::toStreamBase(std::ostream& out, case Kind::RELATION_GROUP_OP: case Kind::RELATION_AGGREGATE_OP: case Kind::RELATION_PROJECT_OP: + case Kind::RELATION_TABLE_JOIN_OP: { ProjectOp op = n.getConst(); const std::vector& indices = op.getIndices(); @@ -1281,6 +1282,7 @@ std::string Smt2Printer::smtKindString(Kind k) case Kind::SET_FILTER: return "set.filter"; case Kind::SET_FOLD: return "set.fold"; case Kind::RELATION_JOIN: return "rel.join"; + case Kind::RELATION_TABLE_JOIN: return "rel.table_join"; case Kind::RELATION_PRODUCT: return "rel.product"; case Kind::RELATION_TRANSPOSE: return "rel.transpose"; case Kind::RELATION_TCLOSURE: return "rel.tclosure"; diff --git a/src/theory/bags/bags_utils.cpp b/src/theory/bags/bags_utils.cpp index 4566bb57a34..b782e4a902c 100644 --- a/src/theory/bags/bags_utils.cpp +++ b/src/theory/bags/bags_utils.cpp @@ -997,8 +997,6 @@ Node BagsUtils::evaluateTableProject(TNode n) std::pair, std::vector> BagsUtils::splitTableJoinIndices(Node n) { - Assert(n.getKind() == Kind::TABLE_JOIN && n.hasOperator() - && n.getOperator().getKind() == Kind::TABLE_JOIN_OP); ProjectOp op = n.getOperator().getConst(); const std::vector& indices = op.getIndices(); size_t joinSize = indices.size() / 2; diff --git a/src/theory/builtin/generic_op.cpp b/src/theory/builtin/generic_op.cpp index e21ac5fc4b7..3b477355168 100644 --- a/src/theory/builtin/generic_op.cpp +++ b/src/theory/builtin/generic_op.cpp @@ -58,16 +58,20 @@ bool GenericOp::isNumeralIndexedOperatorKind(Kind k) return k == Kind::REGEXP_LOOP || k == Kind::BITVECTOR_EXTRACT || k == Kind::BITVECTOR_REPEAT || k == Kind::BITVECTOR_ZERO_EXTEND || k == Kind::BITVECTOR_SIGN_EXTEND || k == Kind::BITVECTOR_ROTATE_LEFT - || k == Kind::BITVECTOR_ROTATE_RIGHT || k == Kind::INT_TO_BITVECTOR || k==Kind::BITVECTOR_BITOF - || k == Kind::IAND || k == Kind::FLOATINGPOINT_TO_FP_FROM_FP + || k == Kind::BITVECTOR_ROTATE_RIGHT || k == Kind::INT_TO_BITVECTOR + || k == Kind::BITVECTOR_BITOF || k == Kind::IAND + || k == Kind::FLOATINGPOINT_TO_FP_FROM_FP || k == Kind::FLOATINGPOINT_TO_FP_FROM_IEEE_BV || k == Kind::FLOATINGPOINT_TO_FP_FROM_SBV - || k == Kind::FLOATINGPOINT_TO_FP_FROM_REAL - || k == Kind::FLOATINGPOINT_TO_FP_FROM_UBV || k == Kind::FLOATINGPOINT_TO_SBV || k == Kind::FLOATINGPOINT_TO_UBV + || k == Kind::FLOATINGPOINT_TO_FP_FROM_REAL + || k == Kind::FLOATINGPOINT_TO_FP_FROM_UBV + || k == Kind::FLOATINGPOINT_TO_SBV || k == Kind::FLOATINGPOINT_TO_UBV || k == Kind::FLOATINGPOINT_TO_SBV_TOTAL - || k == Kind::FLOATINGPOINT_TO_UBV_TOTAL || k == Kind::RELATION_AGGREGATE - || k == Kind::RELATION_PROJECT || k == Kind::RELATION_GROUP || k == Kind::TABLE_PROJECT - || k == Kind::TABLE_AGGREGATE || k == Kind::TABLE_JOIN || k == Kind::TABLE_GROUP; + || k == Kind::FLOATINGPOINT_TO_UBV_TOTAL + || k == Kind::RELATION_AGGREGATE || k == Kind::RELATION_PROJECT + || k == Kind::RELATION_GROUP || k == Kind::TABLE_PROJECT + || k == Kind::RELATION_TABLE_JOIN || k == Kind::TABLE_AGGREGATE + || k == Kind::TABLE_JOIN || k == Kind::TABLE_GROUP; } bool GenericOp::isIndexedOperatorKind(Kind k) @@ -194,6 +198,7 @@ std::vector GenericOp::getIndicesForOperator(Kind k, Node n) break; case Kind::RELATION_AGGREGATE: case Kind::RELATION_PROJECT: + case Kind::RELATION_TABLE_JOIN: case Kind::RELATION_GROUP: case Kind::TABLE_PROJECT: case Kind::TABLE_AGGREGATE: @@ -342,6 +347,8 @@ Node GenericOp::getOperatorForIndices(Kind k, const std::vector& indices) return nm->mkConst(Kind::RELATION_AGGREGATE_OP, ProjectOp(numerals)); case Kind::RELATION_PROJECT: return nm->mkConst(Kind::RELATION_PROJECT_OP, ProjectOp(numerals)); + case Kind::RELATION_TABLE_JOIN: + return nm->mkConst(Kind::RELATION_TABLE_JOIN_OP, ProjectOp(numerals)); case Kind::RELATION_GROUP: return nm->mkConst(Kind::RELATION_GROUP_OP, ProjectOp(numerals)); case Kind::TABLE_PROJECT: diff --git a/src/theory/inference_id.cpp b/src/theory/inference_id.cpp index 26863069cdd..38f0176860c 100644 --- a/src/theory/inference_id.cpp +++ b/src/theory/inference_id.cpp @@ -409,6 +409,9 @@ const char* toString(InferenceId i) case InferenceId::SETS_RELS_JOIN_IMAGE_UP: return "SETS_RELS_JOIN_IMAGE_UP"; case InferenceId::SETS_RELS_JOIN_SPLIT_1: return "SETS_RELS_JOIN_SPLIT_1"; case InferenceId::SETS_RELS_JOIN_SPLIT_2: return "SETS_RELS_JOIN_SPLIT_2"; + case InferenceId::SETS_RELS_TABLE_JOIN_UP: return "SETS_RELS_TABLE_JOIN_UP"; + case InferenceId::SETS_RELS_TABLE_JOIN_DOWN: + return "SETS_RELS_TABLE_JOIN_DOWN"; case InferenceId::SETS_RELS_PRODUCE_COMPOSE: return "SETS_RELS_PRODUCE_COMPOSE"; case InferenceId::SETS_RELS_PRODUCT_SPLIT: return "SETS_RELS_PRODUCT_SPLIT"; diff --git a/src/theory/inference_id.h b/src/theory/inference_id.h index 764f83788bb..97fa41c88b7 100644 --- a/src/theory/inference_id.h +++ b/src/theory/inference_id.h @@ -575,6 +575,8 @@ enum class InferenceId SETS_RELS_JOIN_IMAGE_UP, SETS_RELS_JOIN_SPLIT_1, SETS_RELS_JOIN_SPLIT_2, + SETS_RELS_TABLE_JOIN_UP, + SETS_RELS_TABLE_JOIN_DOWN, SETS_RELS_PRODUCE_COMPOSE, SETS_RELS_PRODUCT_SPLIT, SETS_RELS_TCLOSURE_FWD, diff --git a/src/theory/quantifiers/term_util.cpp b/src/theory/quantifiers/term_util.cpp index 18e48e562c4..e73ff5baf4b 100644 --- a/src/theory/quantifiers/term_util.cpp +++ b/src/theory/quantifiers/term_util.cpp @@ -323,8 +323,8 @@ bool TermUtil::isAssoc(Kind k, bool reqNAry) || k == Kind::BITVECTOR_XOR || k == Kind::BITVECTOR_XNOR || k == Kind::BITVECTOR_CONCAT || k == Kind::STRING_CONCAT || k == Kind::SET_UNION || k == Kind::SET_INTER - || k == Kind::RELATION_JOIN || k == Kind::RELATION_PRODUCT - || k == Kind::SEP_STAR; + || k == Kind::RELATION_JOIN || k == Kind::RELATION_TABLE_JOIN + || k == Kind::RELATION_PRODUCT || k == Kind::SEP_STAR; } bool TermUtil::isComm(Kind k, bool reqNAry) diff --git a/src/theory/sets/kinds b/src/theory/sets/kinds index 07db0da7ebc..107bb5bdd66 100644 --- a/src/theory/sets/kinds +++ b/src/theory/sets/kinds @@ -160,4 +160,17 @@ typerule RELATION_PROJECT ::cvc5::internal::theory::sets::RelationProjectT construle SET_UNION ::cvc5::internal::theory::sets::SetsBinaryOperatorTypeRule construle SET_SINGLETON ::cvc5::internal::theory::sets::SingletonTypeRule +# rel.table_join operator +constant RELATION_TABLE_JOIN_OP \ + class \ + ProjectOp+ \ + ::cvc5::internal::ProjectOpHashFunction \ + "theory/datatypes/project_op.h" \ + "operator for RELATION_TABLE_JOIN; payload is an instance of the cvc5::internal::ProjectOp class" + +parameterized RELATION_TABLE_JOIN RELATION_TABLE_JOIN_OP 2 "relation table join" + +typerule RELATION_TABLE_JOIN_OP "SimpleTypeRule" +typerule RELATION_TABLE_JOIN ::cvc5::internal::theory::sets::RelationTableJoinTypeRule + endtheory diff --git a/src/theory/sets/theory_sets.cpp b/src/theory/sets/theory_sets.cpp index 001b5d96f99..e1df04e572a 100644 --- a/src/theory/sets/theory_sets.cpp +++ b/src/theory/sets/theory_sets.cpp @@ -85,6 +85,7 @@ void TheorySets::finishInit() // relation operators d_equalityEngine->addFunctionKind(Kind::RELATION_PRODUCT); d_equalityEngine->addFunctionKind(Kind::RELATION_JOIN); + d_equalityEngine->addFunctionKind(Kind::RELATION_TABLE_JOIN); d_equalityEngine->addFunctionKind(Kind::RELATION_TRANSPOSE); d_equalityEngine->addFunctionKind(Kind::RELATION_TCLOSURE); d_equalityEngine->addFunctionKind(Kind::RELATION_JOIN_IMAGE); diff --git a/src/theory/sets/theory_sets_private.cpp b/src/theory/sets/theory_sets_private.cpp index c513ef53bc2..eec51a0c151 100644 --- a/src/theory/sets/theory_sets_private.cpp +++ b/src/theory/sets/theory_sets_private.cpp @@ -1470,8 +1470,12 @@ bool TheorySetsPrivate::collectModelValues(TheoryModel* m, const std::set& termSet) { Trace("sets-model") << "Set collect model values" << std::endl; - - NodeManager* nm = nodeManager(); + Trace("sets-model") << "termSet: " << termSet << std::endl; + if(TraceIsOn("sets-model")) + { + Trace("sets-model") <debugPrintModelEqc(); + } + NodeManager* nm = NodeManager::currentNM(); std::map mvals; // If cardinality is enabled, we need to use the ordered equivalence class // list computed by the cardinality solver, where sets equivalence classes diff --git a/src/theory/sets/theory_sets_rels.cpp b/src/theory/sets/theory_sets_rels.cpp index a541f2adc41..c66c4f63b60 100644 --- a/src/theory/sets/theory_sets_rels.cpp +++ b/src/theory/sets/theory_sets_rels.cpp @@ -18,6 +18,7 @@ #include "expr/dtype.h" #include "expr/dtype_cons.h" #include "expr/skolem_manager.h" +#include "theory/datatypes/project_op.h" #include "theory/datatypes/tuple_utils.h" #include "theory/sets/theory_sets.h" #include "theory/sets/theory_sets_private.h" @@ -100,6 +101,14 @@ void TheorySetsRels::check(Theory::Effort level) applyJoinRule( join_terms[j], rel_rep, exp ); } } + if (kind_terms.find(Kind::RELATION_TABLE_JOIN) != kind_terms.end()) + { + std::vector& joinTerms = kind_terms[Kind::RELATION_TABLE_JOIN]; + for (size_t j = 0; j < joinTerms.size(); j++) + { + applyTableJoinRule(joinTerms[j], rel_rep, exp); + } + } if (kind_terms.find(Kind::RELATION_PRODUCT) != kind_terms.end()) { std::vector& product_terms = kind_terms[Kind::RELATION_PRODUCT]; @@ -150,7 +159,8 @@ void TheorySetsRels::check(Theory::Effort level) << " terms of kind " << k_t_it->first << std::endl; std::vector::iterator term_it = k_t_it->second.begin(); if (k_t_it->first == Kind::RELATION_JOIN - || k_t_it->first == Kind::RELATION_PRODUCT) + || k_t_it->first == Kind::RELATION_PRODUCT + || k_t_it->first == Kind::RELATION_TABLE_JOIN) { while (term_it != k_t_it->second.end()) { @@ -258,6 +268,7 @@ void TheorySetsRels::check(Theory::Effort level) { if (eqc_node.getKind() == Kind::RELATION_TRANSPOSE || eqc_node.getKind() == Kind::RELATION_JOIN + || eqc_node.getKind() == Kind::RELATION_TABLE_JOIN || eqc_node.getKind() == Kind::RELATION_PRODUCT || eqc_node.getKind() == Kind::RELATION_TCLOSURE || eqc_node.getKind() == Kind::RELATION_JOIN_IMAGE @@ -953,6 +964,66 @@ void TheorySetsRels::check(Theory::Effort level) makeSharedTerm(shared_x); } + void TheorySetsRels::applyTableJoinRule(Node n, Node nRep, Node exp) + { + Trace("rels-debug") << "\n[Theory::Rels] *********** Applying " + "RELATION_TABLE_JOIN rule on joined term = " + << n << ", its representative = " << nRep + << " with explanation = " << exp << std::endl; + if (d_rel_nodes.find(n) == d_rel_nodes.end()) + { + Trace("rels-debug") + << "\n[Theory::Rels] Apply RELATION_TABLE_JOIN-COMPOSE rule on term: " + << n << " with explanation: " << exp << std::endl; + + computeMembersForBinOpRel(n); + d_rel_nodes.insert(n); + } + NodeManager* nm = NodeManager::currentNM(); + Node A = n[0]; + Node B = n[1]; + Node e = exp[0]; + + Node repA = getRepresentative(A); + Node repB = getRepresentative(B); + + TypeNode tupleAType = A.getType().getSetElementType(); + TypeNode tupleBType = B.getType().getSetElementType(); + size_t tupleALength = tupleAType.getTupleLength(); + size_t productTupleLength = + n.getType().getSetElementType().getTupleLength(); + + std::vector elements = TupleUtils::getTupleElements(e); + Node a = TupleUtils::constructTupleFromElements( + tupleAType, elements, 0, tupleALength - 1); + Node b = TupleUtils::constructTupleFromElements( + tupleBType, elements, tupleALength, productTupleLength - 1); + + computeTupleReps(a); + computeTupleReps(b); + + const std::vector& indices = + n.getOperator().getConst().getIndices(); + Node joinConstraints = d_trueNode; + for (size_t i = 0; i < indices.size(); i += 2) + { + Node x = elements[indices[i]]; + Node y = elements[tupleALength + indices[i + 1]]; + Node equal = x.eqNode(y); + joinConstraints = joinConstraints.andNode(equal); + } + + Node fact1 = nm->mkNode(Kind::SET_MEMBER, a, A); + Node fact2 = nm->mkNode(Kind::SET_MEMBER, b, B); + Node premise = exp; + if (n != exp[1]) + { + premise = premise.andNode(n.eqNode(exp[1])); + } + Node conclusion = fact1.andNode(fact2).andNode(joinConstraints); + sendInfer(conclusion, InferenceId::SETS_RELS_TABLE_JOIN_DOWN, premise); + } + /* * transpose-occur rule: (a, b) IS_IN X (RELATION_TRANSPOSE X) in T * --------------------------------------- @@ -1035,6 +1106,7 @@ void TheorySetsRels::check(Theory::Effort level) } case Kind::RELATION_JOIN: case Kind::RELATION_PRODUCT: + case Kind::RELATION_TABLE_JOIN: { computeMembersForBinOpRel(rel[0]); break; @@ -1050,6 +1122,7 @@ void TheorySetsRels::check(Theory::Effort level) } case Kind::RELATION_JOIN: case Kind::RELATION_PRODUCT: + case Kind::RELATION_TABLE_JOIN: { computeMembersForBinOpRel(rel[1]); break; @@ -1057,7 +1130,24 @@ void TheorySetsRels::check(Theory::Effort level) default: break; } - composeMembersForRels(rel); + Kind k = rel.getKind(); + switch (k) + { + case Kind::RELATION_JOIN: + case Kind::RELATION_PRODUCT: + { + composeMembersForRels(rel); + break; + } + case Kind::RELATION_TABLE_JOIN: + { + applyTableJoinUp(rel); + break; + } + default: + Assert(false) << "No implementation for up rules for kind " << k + << std::endl; + } } // Bottom-up fashion to compute unary relation @@ -1068,7 +1158,8 @@ void TheorySetsRels::check(Theory::Effort level) case Kind::RELATION_TRANSPOSE: case Kind::RELATION_TCLOSURE: computeMembersForUnaryOpRel(rel[0]); break; case Kind::RELATION_JOIN: - case Kind::RELATION_PRODUCT: computeMembersForBinOpRel(rel[0]); break; + case Kind::RELATION_PRODUCT: + case Kind::RELATION_TABLE_JOIN: computeMembersForBinOpRel(rel[0]); break; default: break; } @@ -1214,6 +1305,94 @@ void TheorySetsRels::check(Theory::Effort level) } + void TheorySetsRels::applyTableJoinUp(Node n) + { + Assert(n.getKind() == Kind::RELATION_TABLE_JOIN); + Trace("rels-debug") + << "[Theory::Rels] Start composing members for relation = " << n + << std::endl; + Node a = n[0]; + Node b = n[1]; + Node aRep = getRepresentative(a); + Node bRep = getRepresentative(b); + + if (d_rReps_memberReps_cache.find(aRep) == d_rReps_memberReps_cache.end() + || d_rReps_memberReps_cache.find(bRep) + == d_rReps_memberReps_cache.end()) + { + // no members found for a, b + return; + } + + NodeManager* nm = NodeManager::currentNM(); + + std::vector aMemberships = d_rReps_memberReps_exp_cache[aRep]; + std::vector bMemberships = d_rReps_memberReps_exp_cache[bRep]; + const std::vector& indices = + n.getOperator().getConst().getIndices(); + for (unsigned int i = 0; i < aMemberships.size(); i++) + { + for (unsigned int j = 0; j < bMemberships.size(); j++) + { + Node aConstraint = aMemberships[i]; + Node bConstraint = bMemberships[j]; + Node e1 = aConstraint[0]; + Node e2 = bConstraint[0]; + TypeNode elementType = n.getType().getSetElementType(); + Node tuple = TupleUtils::concatTuples(elementType, e1, e2); + std::vector reasons; + + std::vector aElements = TupleUtils::getTupleElements(e1); + std::vector bElements = TupleUtils::getTupleElements(e2); + + // whether e1, e2 have matching join elements + bool notMatched = false; + for (size_t k = 0; k < indices.size(); k += 2) + { + Node x = aElements[indices[k]]; + Node y = bElements[indices[k + 1]]; + + // Since we require notification x and y are equal, + // they must be shared terms of theory of sets. Hence, we make the + // following calls to makeSharedTerm to ensure this is the case. + makeSharedTerm(x); + makeSharedTerm(y); + + if (!areEqual(x, y)) + { + notMatched = true; + break; + } + else if (x != y) + { + Trace("rels-debug") << "...equal" << std::endl; + reasons.push_back(nm->mkNode(Kind::EQUAL, x, y)); + } + } + + if (notMatched) + { + continue; + } + + Node fact = nm->mkNode(Kind::SET_MEMBER, tuple, n); + reasons.push_back(aConstraint); + reasons.push_back(bConstraint); + if (a != aConstraint[1]) + { + reasons.push_back(nm->mkNode(Kind::EQUAL, a, aConstraint[1])); + } + if (b != bConstraint[1]) + { + reasons.push_back(nm->mkNode(Kind::EQUAL, b, bConstraint[1])); + } + sendInfer(fact, + InferenceId::SETS_RELS_TABLE_JOIN_UP, + nm->mkNode(Kind::AND, reasons)); + } + } + } + void TheorySetsRels::processInference(Node conc, InferenceId id, Node exp) { Trace("sets-pinfer") << "Process inference: " << exp << " => " << conc @@ -1232,8 +1411,9 @@ void TheorySetsRels::check(Theory::Effort level) bool TheorySetsRels::isRelationKind( Kind k ) { return k == Kind::RELATION_TRANSPOSE || k == Kind::RELATION_PRODUCT - || k == Kind::RELATION_JOIN || k == Kind::RELATION_TCLOSURE - || k == Kind::RELATION_IDEN || k == Kind::RELATION_JOIN_IMAGE; + || k == Kind::RELATION_JOIN || k == Kind::RELATION_TABLE_JOIN + || k == Kind::RELATION_TCLOSURE || k == Kind::RELATION_IDEN + || k == Kind::RELATION_JOIN_IMAGE; } Node TheorySetsRels::getRepresentative( Node t ) { @@ -1414,7 +1594,7 @@ void TheorySetsRels::check(Theory::Effort level) bool TupleTrie::addTerm( Node n, std::vector< Node >& reps, int argIndex ){ if( argIndex==(int)reps.size() ){ if( d_data.empty() ){ - //store n in d_data (this should be interpretted as the "data" and not as a reference to a child) + //store n in d_data (this should be interpreted as the "data" and not as a reference to a child) d_data[n].clear(); return true; }else{ diff --git a/src/theory/sets/theory_sets_rels.h b/src/theory/sets/theory_sets_rels.h index 4fe1682ad4f..ac1f8d3d01c 100644 --- a/src/theory/sets/theory_sets_rels.h +++ b/src/theory/sets/theory_sets_rels.h @@ -34,6 +34,12 @@ namespace sets { class TheorySetsPrivate; +/** + * A prefix tree for tuples and their elements' representatives. + * Suppose we have a tuple representative t = , + * then the tuple tree would be + * e1 -> e2 -> ... -> e_n -> t +*/ class TupleTrie { public: /** the data */ @@ -100,7 +106,9 @@ class TheorySetsRels : protected EnvObj NodeSet d_shared_terms; std::unordered_set d_rel_nodes; + /** a map from tuples to their elements' representatives*/ std::map< Node, std::vector > d_tuple_reps; + /** a map from relation terms to their member tuples*/ std::map< Node, TupleTrie > d_membership_trie; /** Symbolic tuple variables that has been reduced to concrete ones */ @@ -149,6 +157,20 @@ class TheorySetsRels : protected EnvObj void applyTransposeRule( Node rel, Node rel_rep, Node exp ); void applyProductRule( Node rel, Node rel_rep, Node exp ); void applyJoinRule( Node rel, Node rel_rep, Node exp); + /** + * @param n is a ((_ table.join m1 n1 ... mk nk) A B) where A, B are tables + * @param nRep a representative of n + * @param exp a membership constraint of the form (set.member e n) + * where e is an element of the form (tuple a1 ... am b1 ... bn) + * This function sends a fact that represents the following + * (=> + * (set.member e n) + * (and + * (= a_{m1} b_{n1}) ... (= a_{mk} b_{nk}) + * (set.member (tuple a1 ... am) A) + * (set.member (tuple b1 ... bn) B))) + */ + void applyTableJoinRule(Node n, Node nRep, Node exp); void applyJoinImageRule( Node mem_rep, Node rel_rep, Node exp); void applyIdenRule( Node mem_rep, Node rel_rep, Node exp); void applyTCRule( Node mem, Node rel, Node rel_rep, Node exp); @@ -166,6 +188,19 @@ class TheorySetsRels : protected EnvObj std::unordered_set& seen); void composeMembersForRels( Node ); + /** + * @param n is ((_ rel.join m1 n1 ... mk nk) A B) where A, B are relations + * This functions looks for current members of A, B. + * For each pair e1 = (tuple a1 ... am) in A, e2 = (tuple b1 ... bn) in B + * this function sends the following fact + * (=> + * (and + * (set.member e1 A) + * (set.member e2 B) + * (= a_{m1} b_{n1}) ... (= a_{mk} b_{nk})) + * (set.member (tuple a1 ... am b1 ... bn) n)) + */ + void applyTableJoinUp(Node); void computeMembersForBinOpRel( Node ); void computeMembersForIdenTerm( Node ); void computeMembersForUnaryOpRel( Node ); diff --git a/src/theory/sets/theory_sets_rewriter.cpp b/src/theory/sets/theory_sets_rewriter.cpp index 80315c1299c..a8277cea151 100644 --- a/src/theory/sets/theory_sets_rewriter.cpp +++ b/src/theory/sets/theory_sets_rewriter.cpp @@ -20,6 +20,7 @@ #include "expr/dtype_cons.h" #include "expr/elim_shadow_converter.h" #include "options/sets_options.h" +#include "theory/bags/bags_utils.h" #include "theory/datatypes/tuple_utils.h" #include "theory/sets/normal_form.h" #include "theory/sets/rels_utils.h" @@ -347,6 +348,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { case Kind::SET_COMPREHENSION: return postRewriteComprehension(node); break; + case Kind::RELATION_TABLE_JOIN: return postRewriteTableJoin(node); break; case Kind::SET_MAP: return postRewriteMap(node); case Kind::SET_FILTER: return postRewriteFilter(node); case Kind::SET_FOLD: return postRewriteFold(node); @@ -666,6 +668,51 @@ RewriteResponse TheorySetsRewriter::postRewriteComprehension(TNode n) return RewriteResponse(REWRITE_DONE, n); } +RewriteResponse TheorySetsRewriter::postRewriteTableJoin(TNode n) +{ + Assert(n.getKind() == Kind::RELATION_TABLE_JOIN); + + Node A = n[0]; + Node B = n[1]; + TypeNode tupleType = n.getType().getSetElementType(); + if (A.isConst() && B.isConst()) + { + auto [aIndices, bIndices] = bags::BagsUtils::splitTableJoinIndices(n); + + std::set elementsA = NormalForm::getElementsFromNormalConstant(A); + std::set elementsB = NormalForm::getElementsFromNormalConstant(B); + std::set newSet; + + for (const auto& a : elementsA) + { + for (const auto& b : elementsB) + { + bool notMatched = false; + for (size_t i = 0; i < aIndices.size(); i++) + { + Node aElement = TupleUtils::nthElementOfTuple(a, aIndices[i]); + Node bElement = TupleUtils::nthElementOfTuple(b, bIndices[i]); + if (aElement != bElement) + { + notMatched = true; + } + } + if (notMatched) + { + continue; + } + Node element = TupleUtils::concatTuples(tupleType, a, b); + newSet.insert(element); + } + } + + Node ret = NormalForm::elementsToSet(newSet, n.getType()); + + return RewriteResponse(REWRITE_AGAIN_FULL, ret); + } + return RewriteResponse(REWRITE_DONE, n); +} + RewriteResponse TheorySetsRewriter::postRewriteMap(TNode n) { Assert(n.getKind() == Kind::SET_MAP); diff --git a/src/theory/sets/theory_sets_rewriter.h b/src/theory/sets/theory_sets_rewriter.h index eb0b6457d99..b50b8e054f9 100644 --- a/src/theory/sets/theory_sets_rewriter.h +++ b/src/theory/sets/theory_sets_rewriter.h @@ -80,6 +80,7 @@ class TheorySetsRewriter : public TheoryRewriter * Rewrite set comprehension */ RewriteResponse postRewriteComprehension(TNode n); + RewriteResponse postRewriteTableJoin(TNode n); /** * rewrites for n include: * - (set.map f (as set.empty (Set T1)) = (as set.empty (Set T2)) diff --git a/src/theory/sets/theory_sets_type_rules.cpp b/src/theory/sets/theory_sets_type_rules.cpp index 14831f2d5d3..52f0d4ea578 100644 --- a/src/theory/sets/theory_sets_type_rules.cpp +++ b/src/theory/sets/theory_sets_type_rules.cpp @@ -19,10 +19,11 @@ #include "expr/dtype.h" #include "expr/dtype_cons.h" -#include "theory/sets/normal_form.h" -#include "util/cardinality.h" +#include "theory/bags/bags_utils.h" #include "theory/datatypes/project_op.h" #include "theory/datatypes/tuple_utils.h" +#include "theory/sets/normal_form.h" +#include "util/cardinality.h" namespace cvc5::internal { namespace theory { @@ -710,6 +711,97 @@ TypeNode RelBinaryOperatorTypeRule::computeType(NodeManager* nodeManager, return resultType; } +TypeNode RelationTableJoinTypeRule::preComputeType(NodeManager* nm, TNode n) +{ + return TypeNode::null(); +} +TypeNode RelationTableJoinTypeRule::computeType(NodeManager* nm, + TNode n, + bool check, + std::ostream* errOut) +{ + Assert(n.getKind() == Kind::RELATION_TABLE_JOIN && n.hasOperator() + && n.getOperator().getKind() == Kind::RELATION_TABLE_JOIN_OP); + ProjectOp op = n.getOperator().getConst(); + const std::vector& indices = op.getIndices(); + Node A = n[0]; + Node B = n[1]; + TypeNode aType = A.getType(); + TypeNode bType = B.getType(); + + if (check) + { + if (!(aType.isSet() && bType.isSet())) + { + std::stringstream ss; + ss << "RELATION_TABLE_JOIN operator expects two relations. Found '" + << n[0] << "', '" << n[1] << "' of types '" << aType << "', '" << bType + << "' respectively. "; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + TypeNode aTupleType = aType.getSetElementType(); + TypeNode bTupleType = bType.getSetElementType(); + if (!(aTupleType.isTuple() && bTupleType.isTuple())) + { + std::stringstream ss; + ss << "RELATION_TABLE_JOIN operator expects two relations. Found '" + << n[0] << "', '" << n[1] << "' of types '" << aType << "', '" << bType + << "' respectively. "; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + if (indices.size() % 2 != 0) + { + std::stringstream ss; + ss << "RELATION_TABLE_JOIN operator expects even number of indices. " + "Found " + << indices.size() << " in term " << n; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + auto [aIndices, bIndices] = bags::BagsUtils::splitTableJoinIndices(n); + if (!TupleUtils::checkTypeIndices(aTupleType, aIndices)) + { + if (errOut) + { + (*errOut) << "Index in operator of " << n + << " is out of range for the type of its first argument"; + } + return TypeNode::null(); + } + if (!TupleUtils::checkTypeIndices(bTupleType, bIndices)) + { + if (errOut) + { + (*errOut) << "Index in operator of " << n + << " is out of range for the type of its second argument"; + } + return TypeNode::null(); + } + + // check the types of columns + std::vector aTypes = aTupleType.getTupleTypes(); + std::vector bTypes = bTupleType.getTupleTypes(); + for (uint32_t i = 0; i < aIndices.size(); i++) + { + if (aTypes[aIndices[i]] != bTypes[bIndices[i]]) + { + std::stringstream ss; + ss << "RELATION_TABLE_JOIN operator expects column " << aIndices[i] + << " in relation " << n[0] << " to match column " << bIndices[i] + << " in relation " << n[1] << ". But their types are " + << aTypes[aIndices[i]] << " and " << bTypes[bIndices[i]] + << "' respectively. "; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } + } + TypeNode aTupleType = aType.getSetElementType(); + TypeNode bTupleType = bType.getSetElementType(); + TypeNode retTupleType = TupleUtils::concatTupleTypes(aTupleType, bTupleType); + return nm->mkSetType(retTupleType); +} + TypeNode RelTransposeTypeRule::preComputeType(NodeManager* nm, TNode n) { return TypeNode::null(); diff --git a/src/theory/sets/theory_sets_type_rules.h b/src/theory/sets/theory_sets_type_rules.h index c29397ea937..d8b6ffa5bd9 100644 --- a/src/theory/sets/theory_sets_type_rules.h +++ b/src/theory/sets/theory_sets_type_rules.h @@ -252,6 +252,24 @@ struct RelBinaryOperatorTypeRule std::ostream* errOut); }; +/** + * Relation table join operator is indexed by a list of indices (m_1, m_k, n_1, + * ..., n_k). It ensures that it has 2 arguments: + * - A relation of type (Relation X_1 ... X_i) + * - A relation of type (Relation Y_1 ... Y_j) + * such that indices has constraints 0 <= m_1, ..., mk, n_1, ..., n_k <= + * min(i,j) and types has constraints X_{m_1} = Y_{n_1}, ..., X_{m_k} = Y_{n_k}. + * The returned type is (Relation X_1 ... X_i Y_1 ... Y_j) + */ +struct RelationTableJoinTypeRule +{ + static TypeNode preComputeType(NodeManager* nm, TNode n); + static TypeNode computeType(NodeManager* nodeManager, + TNode n, + bool check, + std::ostream* errOut); +}; /* struct RelationTableJoinTypeRule */ + /** * Type rule for unary operator (rel.transpose A) to check that A is a relation * (set of Tuples). For an argument A of type (Relation A1 ... An) diff --git a/src/theory/theory_engine.cpp b/src/theory/theory_engine.cpp index 26af039db5f..053796de08b 100644 --- a/src/theory/theory_engine.cpp +++ b/src/theory/theory_engine.cpp @@ -2062,6 +2062,11 @@ void TheoryEngine::checkTheoryAssertionsWithModel(bool hardFailure) { if (val != d_true) { std::stringstream ss; + for (Node child : assertion) + { + Node value = d_tc->getModel()->getValue(child); + ss << "getValue(" << child << "): " << value << std::endl; + } ss << " " << theoryId << " has an asserted fact that"; if (val == d_false) { diff --git a/test/regress/cli/CMakeLists.txt b/test/regress/cli/CMakeLists.txt index d8a4b1cdbf0..6a455e33b35 100644 --- a/test/regress/cli/CMakeLists.txt +++ b/test/regress/cli/CMakeLists.txt @@ -2863,6 +2863,7 @@ set(regress_1_tests regress1/rels/rel_tc_6.cvc.smt2 regress1/rels/rel_tc_9_1.cvc.smt2 regress1/rels/rel_tp_2.cvc.smt2 + regress1/rels/relation_join1.smt2 regress1/rels/rel_tp_join_2_1.cvc.smt2 regress1/rels/set-strat.cvc.smt2 regress1/rels/strat.cvc.smt2 diff --git a/test/regress/cli/regress1/rels/relation_join1.smt2 b/test/regress/cli/regress1/rels/relation_join1.smt2 new file mode 100644 index 00000000000..6ce8c75c1a4 --- /dev/null +++ b/test/regress/cli/regress1/rels/relation_join1.smt2 @@ -0,0 +1,27 @@ +(set-logic HO_ALL) + +(set-info :status sat) + +(declare-fun Departments () (Relation Int String)) +(declare-fun Students () (Relation Int String Int)) +(declare-fun DepartmentStudents () (Relation Int String Int String Int)) + +(declare-fun d1 () (Tuple Int String)) +(declare-fun d2 () (Tuple Int String)) +(assert (distinct d1 d2)) + +(declare-fun s1 () (Tuple Int String Int)) +(declare-fun s2 () (Tuple Int String Int)) +(assert (distinct s1 s2)) + +(assert + (distinct DepartmentStudents (as set.empty (Relation Int String Int String Int)))) + +(assert (set.member d1 Departments)) +(assert (set.member d2 Departments)) +(assert (set.member s1 Students)) +(assert (set.member s2 Students)) + +(assert (= DepartmentStudents ((_ rel.table_join 0 2) Departments Students))) + +(check-sat) From 278eb5a35072fba8140132dcef66565a9781ac02 Mon Sep 17 00:00:00 2001 From: Daniel Larraz Date: Mon, 22 Apr 2024 13:54:14 -0500 Subject: [PATCH 5/6] Install DLL libraries in the runtime path (#10649) This installs DLL libraries in the runtime path (the bin directory by default), which is common practice on Windows. --- .github/actions/run-tests/action.yml | 2 +- cmake/FindPoly.cmake | 8 ++++---- cmake/deps-helper.cmake | 7 +++++++ src/CMakeLists.txt | 7 +++---- src/parser/CMakeLists.txt | 6 +++--- 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/.github/actions/run-tests/action.yml b/.github/actions/run-tests/action.yml index 750c5310bbb..fc708b7b829 100644 --- a/.github/actions/run-tests/action.yml +++ b/.github/actions/run-tests/action.yml @@ -56,7 +56,7 @@ runs: cd examples mkdir -p build && cd build if [[ "$RUNNER_OS" == "Windows" ]]; then - export PATH="${{ inputs.build-dir }}/install/lib:$PATH" + export PATH="${{ inputs.build-dir }}/install/bin:$PATH" export CMAKE_GENERATOR="MSYS Makefiles" fi cmake .. -DCMAKE_PREFIX_PATH=${{ inputs.build-dir }}/install/lib/cmake \ diff --git a/cmake/FindPoly.cmake b/cmake/FindPoly.cmake index a74e41deea9..fb76e38bbd1 100644 --- a/cmake/FindPoly.cmake +++ b/cmake/FindPoly.cmake @@ -237,8 +237,8 @@ else() ExternalProject_Get_Property(Poly-EP BUILD_BYPRODUCTS INSTALL_DIR) string(REPLACE "" "${INSTALL_DIR}" BUILD_BYPRODUCTS "${BUILD_BYPRODUCTS}") - install(FILES - ${BUILD_BYPRODUCTS} - DESTINATION ${CMAKE_INSTALL_LIBDIR} - ) + # Only install shared libraries + if (BUILD_SHARED_LIBS) + install(FILES ${BUILD_BYPRODUCTS} TYPE ${LIB_BUILD_TYPE}) + endif() endif() diff --git a/cmake/deps-helper.cmake b/cmake/deps-helper.cmake index f9ff08537f7..115fa4e1aed 100644 --- a/cmake/deps-helper.cmake +++ b/cmake/deps-helper.cmake @@ -39,6 +39,13 @@ if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.14") ) endif() +# On Windows, DLL libraries are runtime artifacts +if(BUILD_SHARED_LIBS AND WIN32) + set(LIB_BUILD_TYPE BIN) +else() + set(LIB_BUILD_TYPE LIB) +endif() + # On Windows, we need to have a shell interpreter to call 'configure' if(CMAKE_SYSTEM_NAME STREQUAL "Windows") find_program (SHELL "sh" REQUIRED) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6fe763e7306..db0848524aa 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1456,10 +1456,9 @@ target_include_directories(cvc5 $ $ ) -install(TARGETS cvc5 - EXPORT cvc5-targets - DESTINATION ${CMAKE_INSTALL_LIBDIR} -) +# On Windows, CMake's default install action places +# DLLs into the runtime path (by default "bin") +install(TARGETS cvc5 EXPORT cvc5-targets) if(BUILD_SHARED_LIBS) set_target_properties(cvc5 PROPERTIES SOVERSION ${CVC5_SOVERSION}) diff --git a/src/parser/CMakeLists.txt b/src/parser/CMakeLists.txt index 66c8b054630..0d308412063 100644 --- a/src/parser/CMakeLists.txt +++ b/src/parser/CMakeLists.txt @@ -82,9 +82,9 @@ endif() set_target_properties(cvc5parser PROPERTIES OUTPUT_NAME cvc5parser) target_link_libraries(cvc5parser PRIVATE cvc5) -install(TARGETS cvc5parser - EXPORT cvc5-targets - DESTINATION ${CMAKE_INSTALL_LIBDIR}) +# On Windows, CMake's default install action places +# DLLs into the runtime path (by default "bin") +install(TARGETS cvc5parser EXPORT cvc5-targets) # The generated lexer/parser files define some functions as # __declspec(dllexport) via the ANTLR3_API macro, which leads to lots of From d0bb0af4e066b877355ff66bdd98c7de2e9839c2 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Mon, 22 Apr 2024 15:25:23 -0500 Subject: [PATCH 6/6] Add new infastructure for proofs in the prop layer (#10324) This PR adds the following functionalities: (1) Proof production when using --sat-solver=cadical, which by default is via logging DRAT proofs via cadical's interface. (2) Configurable proof production modes via --prop-proof-mode=X, which is PROOF by default when minisat and SKETCH by default when cadical. For the latter, a third mode SAT_EXTERNAL_PROVE (regardless of whether cadical and minisat is used) is available which dumps the dimacs of the computed unsat core of assumptions + input + theory lemmas, where for cadical this is "reproven" internally using another copy of cadical, and for minisat uses the SAT leaves of its proof. Note that the following changes were required: (1) PropEngine::assertInternal is updated to handle the case where the unsat core mode is assumptions and proofs are enabled (in which case we need to go through ProofCnfStream), (2) SAT proofs are cloned when the prop proof mode is PROOF and we connect CNF proofs. This is required since the SAT proof can be used for multiple purposes. Previously, the implementation did an unintuitive way of cloning the proof when connectCnf=false by "unconnecting" the previous CNF connected SAT proof (https://github.com/cvc5/cvc5/pull/10324/files#diff-d0ebcf17fbc068083603c6006907b2a9ead7f98a5391e470d517ed4720818600L181). This simplifies the logic. This may be a slight hit to performance but simplifies the logic significantly. Also, the previous hack wasn't exactly correct, since it would have been possible to introduce aliasing (internal proof nodes that prove the same thing as input formulas), in which case the proof would be unconnected too much. Instead, now we treat the original SAT proof as something that should not be modified (https://github.com/cvc5/cvc5/pull/10324/files#diff-d0ebcf17fbc068083603c6006907b2a9ead7f98a5391e470d517ed4720818600R380). This was already done on main in 9cbc5b9, --- src/options/proof_options.toml | 26 +++ src/prop/cadical.cpp | 3 +- src/prop/prop_engine.cpp | 51 ++++-- src/prop/prop_proof_manager.cpp | 312 +++++++++++++++++++++++++++++++- src/prop/prop_proof_manager.h | 53 +++++- src/smt/set_defaults.cpp | 55 +++++- 6 files changed, 464 insertions(+), 36 deletions(-) diff --git a/src/options/proof_options.toml b/src/options/proof_options.toml index 0bc89de7fe4..b973e08ab98 100644 --- a/src/options/proof_options.toml +++ b/src/options/proof_options.toml @@ -219,3 +219,29 @@ name = "Proof" type = "bool" default = "false" help = "Print the DRAT proof in binary format" + +[[option]] + name = "satProofMinDimacs" + category = "expert" + long = "sat-proof-min-dimacs" + type = "bool" + default = "true" + help = "Minimize the DIMACs emitted when prop-proof-mode is set to sat-external-prove" + +[[option]] + name = "propProofMode" + category = "regular" + long = "prop-proof-mode=MODE" + type = "PropProofMode" + default = "PROOF" + help = "modes for proof granularity" + help_mode = "Modes for proof granularity." +[[option.mode.PROOF]] + name = "proof" + help = "A proof computed by the SAT solver." +[[option.mode.SAT_EXTERNAL_PROVE]] + name = "sat-external-prove" + help = "A proof containing a step that will be proven externally." +[[option.mode.SKETCH]] + name = "sketch" + help = "A sketch given by the SAT solver." diff --git a/src/prop/cadical.cpp b/src/prop/cadical.cpp index cc76d97d618..fa0fcca7460 100644 --- a/src/prop/cadical.cpp +++ b/src/prop/cadical.cpp @@ -1253,7 +1253,8 @@ std::vector CadicalSolver::getOrderHeap() const { return {}; } std::shared_ptr CadicalSolver::getProof() { - Unimplemented() << "getProof for CaDiCaL not supported"; + // do not throw an exception, since we test whether the proof is available + // by comparing it to nullptr. return nullptr; } diff --git a/src/prop/prop_engine.cpp b/src/prop/prop_engine.cpp index e2aed0104f8..a28a2f1f6b5 100644 --- a/src/prop/prop_engine.cpp +++ b/src/prop/prop_engine.cpp @@ -83,16 +83,19 @@ PropEngine::PropEngine(Env& env, TheoryEngine* te) Trace("prop") << "Constructing the PropEngine" << std::endl; context::UserContext* userContext = d_env.getUserContext(); - if (options().prop.satSolver == options::SatSolverMode::MINISAT - || d_env.isSatProofProducing()) + if (options().prop.satSolver == options::SatSolverMode::MINISAT) { d_satSolver = SatSolverFactory::createCDCLTMinisat(d_env, statisticsRegistry()); } else { + // log DRAT proofs if the mode is SKETCH. + bool logProofs = + (env.isSatProofProducing() + && options().proof.propProofMode == options::PropProofMode::SKETCH); d_satSolver = SatSolverFactory::createCadicalCDCLT( - d_env, statisticsRegistry(), env.getResourceManager()); + d_env, statisticsRegistry(), env.getResourceManager(), "", logProofs); } // CNF stream and theory proxy required pointers to each other, make the @@ -110,7 +113,8 @@ PropEngine::PropEngine(Env& env, TheoryEngine* te) bool satProofs = d_env.isSatProofProducing(); if (satProofs) { - d_ppm.reset(new PropPfManager(env, d_satSolver, *d_cnfStream)); + d_ppm.reset( + new PropPfManager(env, d_satSolver, *d_cnfStream, d_assumptions)); } // connect SAT solver d_satSolver->initialize( @@ -252,34 +256,43 @@ void PropEngine::assertInternal(theory::InferenceId id, bool input, ProofGenerator* pg) { - // Assert as (possibly) removable - if (options().smt.unsatCoresMode == options::UnsatCoresMode::ASSUMPTIONS) + bool addAssumption = false; + if (isProofEnabled()) { - if (input) + if (input + && options().smt.unsatCoresMode == options::UnsatCoresMode::ASSUMPTIONS) { - d_cnfStream->ensureLiteral(node); - if (negated) - { - d_assumptions.push_back(node.notNode()); - } - else - { - d_assumptions.push_back(node); - } + // use the proof CNF stream to ensure the literal + d_ppm->ensureLiteral(node); + addAssumption = true; } else { - d_cnfStream->convertAndAssert(node, removable, negated); + d_ppm->convertAndAssert(id, node, negated, removable, input, pg); } } - else if (isProofEnabled()) + else if (input + && options().smt.unsatCoresMode + == options::UnsatCoresMode::ASSUMPTIONS) { - d_ppm->convertAndAssert(id, node, negated, removable, input, pg); + d_cnfStream->ensureLiteral(node); + addAssumption = true; } else { d_cnfStream->convertAndAssert(node, removable, negated); } + if (addAssumption) + { + if (negated) + { + d_assumptions.push_back(node.notNode()); + } + else + { + d_assumptions.push_back(node); + } + } } void PropEngine::assertLemmasInternal( diff --git a/src/prop/prop_proof_manager.cpp b/src/prop/prop_proof_manager.cpp index c04cade610c..4f5c9ca49df 100644 --- a/src/prop/prop_proof_manager.cpp +++ b/src/prop/prop_proof_manager.cpp @@ -15,21 +15,27 @@ #include "prop/prop_proof_manager.h" +#include "expr/skolem_manager.h" +#include "options/main_options.h" +#include "printer/printer.h" #include "proof/proof_ensure_closed.h" #include "proof/proof_node_algorithm.h" #include "proof/theory_proof_step_buffer.h" #include "prop/cnf_stream.h" -#include "prop/prop_proof_manager.h" #include "prop/minisat/sat_proof_manager.h" +#include "prop/prop_proof_manager.h" #include "prop/sat_solver.h" +#include "prop/sat_solver_factory.h" #include "smt/env.h" +#include "util/string.h" namespace cvc5::internal { namespace prop { PropPfManager::PropPfManager(Env& env, CDCLTSatSolver* satSolver, - CnfStream& cnf) + CnfStream& cnf, + const context::CDList& assumptions) : EnvObj(env), d_propProofs(userContext()), // Since the ProofCnfStream performs no equality reasoning, there is no @@ -54,6 +60,7 @@ PropPfManager::PropPfManager(Env& env, d_satSolver(satSolver), d_assertions(userContext()), d_cnfStream(cnf), + d_assumptions(assumptions), d_inputClauses(userContext()), d_lemmaClauses(userContext()), d_satPm(nullptr) @@ -114,12 +121,17 @@ std::vector PropPfManager::getUnsatCoreLemmas() { std::vector usedLemmas; std::vector allLemmas = getLemmaClauses(); - std::shared_ptr satPf = getProof(false); - std::vector satLeaves; - expr::getFreeAssumptions(satPf.get(), satLeaves); + // compute the unsat core clauses, as below + std::vector ucc = getUnsatCoreClauses(); + Trace("prop-pf") << "Compute unsat core lemmas from " << ucc.size() + << " clauses (of " << allLemmas.size() << " lemmas)" + << std::endl; + Trace("prop-pf") << "lemmas: " << allLemmas << std::endl; + Trace("prop-pf") << "uc: " << ucc << std::endl; + // filter to only those corresponding to lemmas for (const Node& lemma : allLemmas) { - if (std::find(satLeaves.begin(), satLeaves.end(), lemma) != satLeaves.end()) + if (std::find(ucc.begin(), ucc.end(), lemma) != ucc.end()) { usedLemmas.push_back(lemma); } @@ -127,6 +139,175 @@ std::vector PropPfManager::getUnsatCoreLemmas() return usedLemmas; } +std::vector PropPfManager::getMinimizedAssumptions() +{ + std::vector minAssumptions; + std::vector unsatAssumptions; + d_satSolver->getUnsatAssumptions(unsatAssumptions); + for (const Node& nc : d_assumptions) + { + if (nc.isConst()) + { + if (nc.getConst()) + { + // never include true + continue; + } + minAssumptions.clear(); + minAssumptions.push_back(nc); + return minAssumptions; + } + else if (d_pfCnfStream.hasLiteral(nc)) + { + SatLiteral il = d_pfCnfStream.getLiteral(nc); + if (std::find(unsatAssumptions.begin(), unsatAssumptions.end(), il) + == unsatAssumptions.end()) + { + continue; + } + } + else + { + Assert(false) << "Missing literal for assumption " << nc; + } + minAssumptions.push_back(nc); + } + return minAssumptions; +} + +std::vector PropPfManager::getUnsatCoreClauses(std::ostream* outDimacs) +{ + std::vector uc; + // if it has a proof + std::shared_ptr satPf = d_satSolver->getProof(); + if (satPf != nullptr) + { + // then, get the proof *without* connecting the CNF + expr::getFreeAssumptions(satPf.get(), uc); + if (outDimacs != nullptr) + { + std::vector auxUnits = computeAuxiliaryUnits(uc); + d_pfCnfStream.dumpDimacs(*outDimacs, uc, auxUnits); + // include the auxiliary units if any + uc.insert(uc.end(), auxUnits.begin(), auxUnits.end()); + } + return uc; + } + // otherwise we need to compute it + // as a minor optimization, we use only minimized assumptions + std::vector minAssumptions = getMinimizedAssumptions(); + std::unordered_set cset(minAssumptions.begin(), minAssumptions.end()); + std::vector inputs = getInputClauses(); + std::vector lemmas = getLemmaClauses(); + cset.insert(inputs.begin(), inputs.end()); + cset.insert(lemmas.begin(), lemmas.end()); + if (!reproveUnsatCore(cset, uc, outDimacs)) + { + // otherwise, must include all + return getLemmaClauses(); + } + return uc; +} + +bool PropPfManager::reproveUnsatCore(const std::unordered_set& cset, + std::vector& uc, + std::ostream* outDimacs) +{ + CDCLTSatSolver* csm = SatSolverFactory::createCadical( + d_env, statisticsRegistry(), d_env.getResourceManager(), ""); + NullRegistrar nreg; + context::Context nctx; + CnfStream csms(d_env, csm, &nreg, &nctx); + Trace("cnf-input-min") << "Get literals..." << std::endl; + std::vector csma; + std::map litToNode; + std::map litToNodeAbs; + NodeManager* nm = NodeManager::currentNM(); + TypeNode bt = nm->booleanType(); + TypeNode ft = nm->mkFunctionType({bt}, bt); + SkolemManager* skm = nm->getSkolemManager(); + // Function used to ensure that subformulas are not treated by CNF below. + Node litOf = skm->mkDummySkolem("litOf", ft); + for (const Node& c : cset) + { + Node ca = c; + std::vector satClause; + std::vector lits; + if (c.getKind() == Kind::OR) + { + lits.insert(lits.end(), c.begin(), c.end()); + } + else + { + lits.push_back(c); + } + // For each literal l in the current clause, if it has Boolean + // substructure, we replace it with (litOf l), which will be treated as a + // literal. We do this since we require that the clause be treated + // verbatim by the SAT solver, otherwise the unsat core will not include + // the necessary clauses (e.g. it will skip those corresponding to CNF + // conversion). + std::vector cls; + bool childChanged = false; + for (const Node& cl : lits) + { + bool negated = cl.getKind() == Kind::NOT; + Node cla = negated ? cl[0] : cl; + if (d_env.theoryOf(cla) == theory::THEORY_BOOL && !cla.isVar()) + { + Node k = nm->mkNode(Kind::APPLY_UF, {litOf, cla}); + cls.push_back(negated ? k.notNode() : k); + childChanged = true; + } + else + { + cls.push_back(cl); + } + } + if (childChanged) + { + ca = nm->mkOr(cls); + } + Trace("cnf-input-min-assert") << "Assert: " << ca << std::endl; + csms.ensureLiteral(ca); + SatLiteral lit = csms.getLiteral(ca); + csma.emplace_back(lit); + litToNode[lit] = c; + litToNodeAbs[lit] = ca; + } + Trace("cnf-input-min") << "Solve under " << csma.size() << " assumptions..." + << std::endl; + SatValue res = csm->solve(csma); + if (res == SAT_VALUE_FALSE) + { + // we successfully reproved the input + Trace("cnf-input-min") << "...got unsat" << std::endl; + std::vector uassumptions; + csm->getUnsatAssumptions(uassumptions); + Trace("cnf-input-min") << "...#unsat assumptions=" << uassumptions.size() + << std::endl; + std::vector aclauses; + for (const SatLiteral& lit : uassumptions) + { + Assert(litToNode.find(lit) != litToNode.end()); + Trace("cnf-input-min-result") + << "assert: " << litToNode[lit] << std::endl; + uc.emplace_back(litToNode[lit]); + aclauses.emplace_back(litToNodeAbs[lit]); + } + if (outDimacs) + { + // dump using the CNF stream we created above + csms.dumpDimacs(*outDimacs, aclauses); + } + return true; + } + // should never happen, if it does, we revert to the entire input + Trace("cnf-input-min") << "...got sat" << std::endl; + Assert(false) << "Failed to minimize DIMACS"; + return false; +} + std::vector> PropPfManager::getProofLeaves( modes::ProofComponent pc) { @@ -161,9 +342,25 @@ std::shared_ptr PropPfManager::getProof(bool connectCnf) return it->second; } // retrieve the SAT solver's refutation proof - Trace("sat-proof") - << "PropPfManager::getProof: Getting resolution proof of false\n"; - std::shared_ptr conflictProof = d_satSolver->getProof(); + Trace("sat-proof") << "PropPfManager::getProof: Getting proof of false\n"; + + // get the proof based on the proof mode + options::PropProofMode pmode = options().proof.propProofMode; + std::shared_ptr conflictProof; + if (pmode == options::PropProofMode::PROOF) + { + // take proof from SAT solver as is + conflictProof = d_satSolver->getProof(); + } + else + { + // set up a proof and get the internal proof + CDProof cdp(d_env); + getProofInternal(&cdp); + Node falsen = NodeManager::currentNM()->mkConst(false); + conflictProof = cdp.getProofFor(falsen); + } + Assert(conflictProof); if (TraceIsOn("sat-proof")) { @@ -187,7 +384,10 @@ std::shared_ptr PropPfManager::getProof(bool connectCnf) } // Must clone if we are using the original proof, since we don't want to // modify the original SAT proof. - conflictProof = conflictProof->clone(); + if (pmode == options::PropProofMode::PROOF) + { + conflictProof = conflictProof->clone(); + } // connect it with CNF proof d_pfpp->process(conflictProof); if (TraceIsOn("sat-proof")) @@ -253,6 +453,98 @@ Node PropPfManager::normalizeAndRegister(TNode clauseNode, LazyCDProof* PropPfManager::getCnfProof() { return &d_proof; } +void PropPfManager::getProofInternal(CDProof* cdp) +{ + // This method is called when the SAT solver did not generate a fully self + // contained ProofNode proving false. This method adds a step to cdp + // based on a set of computed assumptions, possibly relying on the internal + // proof. + NodeManager* nm = NodeManager::currentNM(); + Node falsen = nm->mkConst(false); + std::vector clauses; + // deduplicate assumptions + Trace("cnf-input") << "#assumptions=" << d_assumptions.size() << std::endl; + std::vector minAssumptions = getMinimizedAssumptions(); + if (minAssumptions.size() == 1 && minAssumptions[0] == falsen) + { + // if false exists, no proof is necessary + return; + } + std::unordered_set cset(minAssumptions.begin(), minAssumptions.end()); + Trace("cnf-input") << "#assumptions (min)=" << cset.size() << std::endl; + std::vector inputs = getInputClauses(); + Trace("cnf-input") << "#input=" << inputs.size() << std::endl; + std::vector lemmas = getLemmaClauses(); + Trace("cnf-input") << "#lemmas=" << lemmas.size() << std::endl; + cset.insert(inputs.begin(), inputs.end()); + cset.insert(lemmas.begin(), lemmas.end()); + + // Otherwise, we will dump a DIMACS. The proof further depends on the + // mode, which we handle below. + std::stringstream dinputFile; + dinputFile << options().driver.filename << ".drat_input.cnf"; + // the stream which stores the DIMACS of the computed clauses + std::fstream dout(dinputFile.str(), std::ios::out); + options::PropProofMode pmode = options().proof.propProofMode; + // minimize only if SAT_EXTERNAL_PROVE and satProofMinDimacs is true. + bool minimal = (pmode == options::PropProofMode::SAT_EXTERNAL_PROVE + && options().proof.satProofMinDimacs); + // go back and minimize assumptions if minimal is true + bool computedClauses = false; + if (minimal) + { + // get the unsat core clauses + std::shared_ptr satPf = d_satSolver->getProof(); + if (satPf != nullptr) + { + clauses = getUnsatCoreClauses(&dout); + computedClauses = true; + } + else if (reproveUnsatCore(cset, clauses, &dout)) + { + computedClauses = true; + } + else + { + // failed to reprove + } + } + // if we did not minimize, just include all + if (!computedClauses) + { + // if no minimization is necessary, just include all + clauses.insert(clauses.end(), cset.begin(), cset.end()); + std::vector auxUnits = computeAuxiliaryUnits(clauses); + d_pfCnfStream.dumpDimacs(dout, clauses, auxUnits); + // include the auxiliary units if any + clauses.insert(clauses.end(), auxUnits.begin(), auxUnits.end()); + } + // construct the proof + std::vector args; + Node dfile = nm->mkConst(String(dinputFile.str())); + args.push_back(dfile); + ProofRule r = ProofRule::UNKNOWN; + if (pmode == options::PropProofMode::SKETCH) + { + // if sketch, get the rule and arguments from the SAT solver. + std::pair> sk = d_satSolver->getProofSketch(); + r = sk.first; + args.insert(args.end(), sk.second.begin(), sk.second.end()); + } + else if (pmode == options::PropProofMode::SAT_EXTERNAL_PROVE) + { + // if SAT_EXTERNAL_PROVE, the rule is fixed and there are no additional + // arguments. + r = ProofRule::SAT_EXTERNAL_PROVE; + } + else + { + Assert(false) << "Unknown proof mode " << pmode; + } + // use the rule, clauses and arguments we computed above + cdp->addStep(falsen, r, clauses, args); +} + std::vector PropPfManager::computeAuxiliaryUnits( const std::vector& clauses) { diff --git a/src/prop/prop_proof_manager.h b/src/prop/prop_proof_manager.h index d561b545750..4fb35369d1f 100644 --- a/src/prop/prop_proof_manager.h +++ b/src/prop/prop_proof_manager.h @@ -48,7 +48,16 @@ class PropPfManager : protected EnvObj friend class SatProofManager; public: - PropPfManager(Env& env, CDCLTSatSolver* satSolver, CnfStream& cnfProof); + /** + * @param env The environment + * @param satSolver Pointer to the SAT solver + * @param cnfProof Pointer to the CNF stream + * @param assumptions Reference to assumptions of parent prop engine + */ + PropPfManager(Env& env, + CDCLTSatSolver* satSolver, + CnfStream& cnfProof, + const context::CDList& assumptions); /** * Ensure that the given node will have a designated SAT literal that is * definitionally equal to it. The result of this function is that the Node @@ -169,6 +178,46 @@ class PropPfManager : protected EnvObj std::vector getInputClauses(); /** Retrieve the clauses derived from lemmas */ std::vector getLemmaClauses(); + /** + * Return theory lemmas used for showing unsat. If the SAT solver has a proof, + * we examine its leaves. Otherwise, we recompute the unsat core lemmas + * using the method reproveUnsatCore. + * + * @param outDimacs If provided, we write the DIMACS output of uc to this + * stream + * @return the unsat core of lemmas. + */ + std::vector getUnsatCoreClauses(std::ostream* outDimacs = nullptr); + /** + * Get minimized assumptions. Returns a vector of nodes which is a + * subset of the assumptions (d_assumptions) that appear in the unsat + * core. This should be called only when the unsat core is available (after + * an unsatisfiable check-sat). + */ + std::vector getMinimizedAssumptions(); + /** + * Calculate a subset of cset that is propositionally unsatisfiable. + * If sucessful, return true and store this in uc. + * + * @param cset The set of formulas to compute an unsat core for + * @param uc The set of formulas returned as the unsat core + * @param outDimacs If provided, we write a DIMACS representation of uc to + * this stream + */ + bool reproveUnsatCore(const std::unordered_set& cset, + std::vector& uc, + std::ostream* outDimacs = nullptr); + /** + * Add a proof of false to cdp whose free assumptions are a subset of the + * clauses (after CNF conversion), which is a union of: + * (1) assumptions (d_assumptions), + * (2) input clauses (d_inputClauses), + * (3) lemma clauses (d_lemmaClauses). + * The choice of what to add to cdp is dependent on the prop-proof-mode. + * + * @param cdp The proof object to add the refutation proof to. + */ + void getProofInternal(CDProof* cdp); /** * Get auxilary units. Computes top-level formulas in clauses that * also occur as literals which we call "auxiliary units". In particular, @@ -215,6 +264,8 @@ class PropPfManager : protected EnvObj context::CDList d_assertions; /** The cnf stream proof generator */ CnfStream& d_cnfStream; + /** Reference to the assumptions of the parent prop engine */ + const context::CDList& d_assumptions; /** Asserted clauses derived from the input */ context::CDHashSet d_inputClauses; /** Asserted clauses derived from lemmas */ diff --git a/src/smt/set_defaults.cpp b/src/smt/set_defaults.cpp index 3a0b78e2e47..1fd68f8d191 100644 --- a/src/smt/set_defaults.cpp +++ b/src/smt/set_defaults.cpp @@ -173,10 +173,23 @@ void SetDefaults::setDefaultsPre(Options& opts) if (opts.smt.unsatCoresMode != options::UnsatCoresMode::SAT_PROOF) { SET_AND_NOTIFY(Smt, produceUnsatCores, true, "enabling proofs"); - SET_AND_NOTIFY(Smt, - unsatCoresMode, - options::UnsatCoresMode::SAT_PROOF, - "enabling proofs"); + if (options().prop.satSolver == options::SatSolverMode::MINISAT) + { + // if full proofs are available in minisat, use them for unsat cores + SET_AND_NOTIFY(Smt, + unsatCoresMode, + options::UnsatCoresMode::SAT_PROOF, + "enabling proofs, minisat"); + } + else if (options().prop.satSolver == options::SatSolverMode::CADICAL) + { + // unsat cores available by assumptions by default if proofs are enabled + // with CaDiCaL. + SET_AND_NOTIFY(Smt, + unsatCoresMode, + options::UnsatCoresMode::ASSUMPTIONS, + "enabling proofs, non-minisat"); + } } // note that this test assumes that granularity modes are ordered and // THEORY_REWRITE is gonna be, in the enum, after the lower granularity @@ -239,6 +252,21 @@ void SetDefaults::setDefaultsPre(Options& opts) } } } + if (opts.smt.produceProofs) + { + // determine the prop proof mode, based on which SAT solver we are using + if (!opts.proof.propProofModeWasSetByUser) + { + if (opts.prop.satSolver == options::SatSolverMode::CADICAL) + { + // use SAT_EXTERNAL_PROVE for cadical by default + SET_AND_NOTIFY(Proof, + propProofMode, + options::PropProofMode::SAT_EXTERNAL_PROVE, + "cadical"); + } + } + } // if unsat cores are disabled, then unsat cores mode should be OFF. Similarly // for proof mode. @@ -247,7 +275,7 @@ void SetDefaults::setDefaultsPre(Options& opts) Assert(opts.smt.produceProofs == (opts.smt.proofMode != options::ProofMode::OFF)); - // if we requiring disabling proofs, disable them now + // if we require disabling options due to proofs, disable them now if (opts.smt.produceProofs) { std::stringstream reasonNoProofs; @@ -977,6 +1005,23 @@ bool SetDefaults::incompatibleWithProofs(Options& opts, reason << "deep restarts"; return true; } + // specific to SAT solver + if (opts.prop.satSolver == options::SatSolverMode::CADICAL) + { + if (opts.proof.propProofMode == options::PropProofMode::PROOF) + { + reason << "(resolution) proofs not supported in cadical"; + return true; + } + } + else if (opts.prop.satSolver == options::SatSolverMode::MINISAT) + { + if (opts.proof.propProofMode == options::PropProofMode::SKETCH) + { + reason << "(DRAT) proof sketch not supported in minisat"; + return true; + } + } if (options().theory.lemmaInprocess != options::LemmaInprocessMode::NONE) { // lemma inprocessing introduces depencencies from learned unit literals