Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Boolean overloads: &, |, ^, ~ #103

Merged
merged 8 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
with:
options: "--check --verbose"
src: "cvc5_pythonic_api"
version: "23.7.0"
version: "24.10.0"

- uses: actions/checkout@v2
with:
Expand Down Expand Up @@ -56,7 +56,7 @@ jobs:
- name: Build cvc5
run: |
cd cvc5/
./configure.sh production --auto-download --python-bindings --cocoa
./configure.sh production --auto-download --python-bindings --cocoa --gpl
cd build/
make -j${{ env.num_proc }}

Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ check:
pyright ./cvc5_pythonic_api

fmt:
black --required-version 23.7.0 ./cvc5_pythonic_api
black --required-version 24 ./cvc5_pythonic_api

check-fmt:
black --check --verbose --required-version 23.7.0 ./cvc5_pythonic_api
black --check --verbose --required-version 24 ./cvc5_pythonic_api

coverage:
coverage run test_unit.py && coverage report && coverage html
121 changes: 90 additions & 31 deletions cvc5_pythonic_api/cvc5_pythonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
* Missing features:
* Patterns
* Models for uninterpreted sorts
* The `Model` function
* In our API, this function returns an object whose only method is `evaluate`.
* Pseudo-boolean counting constraints
* AtMost, AtLeast, PbLe, PbGe, PbEq
* HTML integration
Expand Down Expand Up @@ -558,9 +560,6 @@ def _ctx_from_ast_arg_list(args, default_ctx=None):
if is_ast(a):
if ctx is None:
ctx = a.ctx
else:
if debugging():
_assert(ctx == a.ctx, "Context mismatch")
if ctx is None:
ctx = default_ctx
return ctx
Expand Down Expand Up @@ -1245,8 +1244,6 @@ def If(a, b, c, ctx=None):
s = BoolSort(ctx)
a = s.cast(a)
b, c = _coerce_exprs(b, c, ctx)
if debugging():
_assert(a.ctx == b.ctx, "Context mismatch")
return _to_expr_ref(ctx.solver.mkTerm(Kind.ITE, a.ast, b.ast, c.ast), ctx)


Expand Down Expand Up @@ -1429,6 +1426,38 @@ def __mul__(self, other):
return 0
return If(self, other, 0)

def __and__(self, other):
"""Create the SMT and expression `self & other`.

>>> solve(Bool("x") & Bool("y"))
[x = True, y = True]
"""
return And(self, other)

def __or__(self, other):
"""Create the SMT or expression `self | other`.

>>> solve(Bool("x") | Bool("y"), Not(Bool("x")))
[x = False, y = True]
"""
return Or(self, other)

def __xor__(self, other):
"""Create the SMT xor expression `self ^ other`.

>>> solve(Bool("x") ^ Bool("y"), Not(Bool("x")))
[x = False, y = True]
"""
return Xor(self, other)

def __invert__(self):
"""Create the SMT not expression `~self`.

>>> solve(~Bool("x"))
[x = False]
"""
return Not(self)


def is_bool(a):
"""Return `True` if `a` is an SMT Boolean expression.
Expand Down Expand Up @@ -1875,8 +1904,6 @@ def cast(self, val):
String
"""
if is_expr(val):
if debugging():
_assert(self.ctx == val.ctx, "Context mismatch")
val_s = val.sort()
if self.eq(val_s):
return val
Expand Down Expand Up @@ -2617,8 +2644,6 @@ def cast(self, val):
failed
"""
if is_expr(val):
if debugging():
_assert(self.ctx == val.ctx, "Context mismatch")
val_s = val.sort()
if self.eq(val_s):
return val
Expand Down Expand Up @@ -4067,8 +4092,6 @@ def cast(self, val):
'#b00000000000000000000000000001010'
"""
if is_expr(val):
if debugging():
_assert(self.ctx == val.ctx, "Context mismatch")
# Idea: use sign_extend if sort of val is a bitvector of smaller size
return val
else:
Expand Down Expand Up @@ -5494,7 +5517,6 @@ def ArraySort(*sig):
if debugging():
for s in sig:
_assert(is_sort(s), "SMT sort expected")
_assert(s.ctx == r.ctx, "Context mismatch")
ctx = d.ctx
if len(sig) == 2:
return ArraySortRef(ctx.solver.mkArraySort(d.ast, r.ast), ctx)
Expand Down Expand Up @@ -6238,12 +6260,22 @@ def proof(self):
[a + 2 == 0, a == 0],
(EQ_RESOLVE: False,
(ASSUME: a == 0, [a == 0]),
(MACRO_SR_EQ_INTRO: (a == 0) == False,
[a == 0, 7, 12],
(EQ_RESOLVE: a == -2,
(ASSUME: a + 2 == 0, [a + 2 == 0]),
(MACRO_SR_EQ_INTRO: (a + 2 == 0) == (a == -2),
[a + 2 == 0, 7, 12]))))))
(TRANS: (a == 0) == False,
(CONG: (a == 0) == (-2 == 0),
[5],
(EQ_RESOLVE: a == -2,
(ASSUME: a + 2 == 0, [a + 2 == 0]),
(TRANS: (a + 2 == 0) == (a == -2),
(CONG: (a + 2 == 0) == (2 + a == 0),
[5],
(TRUST_THEORY_REWRITE: a + 2 == 2 + a,
[a + 2 == 2 + a, 3, 7]),
(REFL: 0 == 0, [0])),
(TRUST_THEORY_REWRITE: (2 + a == 0) == (a == -2),
[(2 + a == 0) == (a == -2), 3, 7]))),
(REFL: 0 == 0, [0])),
(TRUST_THEORY_REWRITE: (-2 == 0) == False,
[(-2 == 0) == False, 3, 7])))))
"""
p = self.solver.getProof()[0]
return ProofRef(self, p)
Expand Down Expand Up @@ -6789,13 +6821,36 @@ def decls(self):


def evaluate(t):
"""Evaluates the given term (assuming it is constant!)"""
"""Evaluates the given term (assuming it is constant!)

>>> evaluate(evaluate(BitVecVal(1, 8) + BitVecVal(2, 8)) + BitVecVal(3, 8))
6
"""
if not isinstance(t, ExprRef):
raise TypeError("Can only evaluation `ExprRef`s")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo

s = Solver()
s.check()
m = s.model()
return m[t]


class EmptyModel:
def evaluate(self, t):
return evaluate(t)


def Model(ctx=None):
"""Return an object for evaluating terms.

We recommend using the standalone `evaluate` function for this instead,
but we also provide this function and its return object for z3 compatibility.

>>> Model().evaluate(BitVecVal(1, 8) + BitVecVal(2, 8))
3
"""
return EmptyModel()


class ProofRef:
"""A proof tree where every proof reference corresponds to the
root step of a proof. The branches of the root step are the
Expand Down Expand Up @@ -6857,12 +6912,22 @@ def getChildren(self):
>>> p
(EQ_RESOLVE: False,
(ASSUME: a == 0, [a == 0]),
(MACRO_SR_EQ_INTRO: (a == 0) == False,
[a == 0, 7, 12],
(EQ_RESOLVE: a == -2,
(ASSUME: a + 2 == 0, [a + 2 == 0]),
(MACRO_SR_EQ_INTRO: (a + 2 == 0) == (a == -2),
[a + 2 == 0, 7, 12]))))
(TRANS: (a == 0) == False,
(CONG: (a == 0) == (-2 == 0),
[5],
(EQ_RESOLVE: a == -2,
(ASSUME: a + 2 == 0, [a + 2 == 0]),
(TRANS: (a + 2 == 0) == (a == -2),
(CONG: (a + 2 == 0) == (2 + a == 0),
[5],
(TRUST_THEORY_REWRITE: a + 2 == 2 + a,
[a + 2 == 2 + a, 3, 7]),
(REFL: 0 == 0, [0])),
(TRUST_THEORY_REWRITE: (2 + a == 0) == (a == -2),
[(2 + a == 0) == (a == -2), 3, 7]))),
(REFL: 0 == 0, [0])),
(TRUST_THEORY_REWRITE: (-2 == 0) == False,
[(-2 == 0) == False, 3, 7])))
"""
children = self.proof.getChildren()
return [ProofRef(self.solver, cp) for cp in children]
Expand Down Expand Up @@ -6965,8 +7030,6 @@ def cast(self, val):
'(fp #b0 #b01111111 #b00000000000000000000000)'
"""
if is_expr(val):
if debugging():
_assert(self.ctx == val.ctx, "Context mismatch")
return val
else:
return FPVal(val, None, self, self.ctx)
Expand Down Expand Up @@ -8633,7 +8696,6 @@ def CreateDatatypes(*ds):
_assert(
all([isinstance(d, Datatype) for d in ds]), "Arguments must be Datatypes"
)
_assert(all([d.ctx == ds[0].ctx for d in ds]), "Context mismatch")
_assert(all([d.constructors != [] for d in ds]), "Non-empty Datatypes expected")
ctx = ds[0].ctx
s = ctx.solver
Expand Down Expand Up @@ -9240,9 +9302,6 @@ def cast(self, val):
'#f10m31'
"""
if is_expr(val):
if debugging():
_assert(self.ctx == val.ctx, "Context mismatch")
# Idea: use sign_extend if sort of val is a bitvector of smaller size
return val
else:
return FiniteFieldVal(val, self)
Expand Down
Loading