From bc45b284fe87dc3874fc438e3837d3312cdcfd0a Mon Sep 17 00:00:00 2001 From: Emiko Soroka Date: Tue, 28 Nov 2023 21:15:28 -0800 Subject: [PATCH 1/7] All tests pass. (1) Expanded ite(x,y,z) to allow non-Boolean values for y and z and simplified code. (2) Added promote_type and convert rules to codify promotion rules for IntExpr, BoolExpr and Real Expr. (3) Added integer division and cleaned up type conversion around integer and real-valued division. (4) Added to_real and to_int SMT-LIB functions. (5) Fixed issue#21 using (2) and (4) to ensure correct promotion by wrapping promoted BoolExprs using ite and wrapping promoted IntExprs using to_real. --- docs/src/functions.md | 9 ++- src/BooleanOperations.jl | 18 +++--- src/IntExpr.jl | 107 ++++++++++++++++++++++++-------- src/Satisfiability.jl | 8 ++- src/sat.jl | 2 +- src/smt_representation.jl | 4 +- src/utilities.jl | 4 +- test/boolean_operation_tests.jl | 4 +- test/int_real_tests.jl | 2 +- test/output_parse_tests.jl | 6 +- test/solver_interface_tests.jl | 4 +- 11 files changed, 118 insertions(+), 50 deletions(-) diff --git a/docs/src/functions.md b/docs/src/functions.md index 23e7542..31a6774 100644 --- a/docs/src/functions.md +++ b/docs/src/functions.md @@ -32,7 +32,7 @@ distinct(z1::BoolExpr, z2::BoolExpr) ## Arithmetic operations These are operations in the theory of integer and real-valued arithmetic. -Note that `+`, `-`, and `*` follow type promotion rules: if both `a` and `b` are `IntExpr`s, `a+b` will have type `IntExpr`. If either `a` or `b` is a `RealExpr`, the result will have type `RealExpr`. Division `\` is defined only in the theory of real-valued arithmetic, thus it always has return type `RealExpr`. +Note that `+`, `-`, and `*` follow type promotion rules: if both `a` and `b` are `IntExpr`s, `a+b` will have type `IntExpr`. If either `a` or `b` is a `RealExpr`, the result will have type `RealExpr`. Integer division `div(a,b)` is defined only for `IntExpr`s. Real-valued division `a\b` is defined only in the theory of real-valued arithmetic. For a formal definition of the theory of integer arithmetic, see Figure 3.3 in *The SMT-LIB Standard, Version 2.6*. ```@docs @@ -40,6 +40,7 @@ Base.:-(a::IntExpr) Base.:+(a::IntExpr, b::IntExpr) Base.:-(a::IntExpr, b::IntExpr) Base.:*(a::RealExpr, b::RealExpr) +Base.div(a::IntExpr, b::IntExpr) Base.:/(a::RealExpr, b::RealExpr) ``` @@ -57,6 +58,12 @@ Base.:>(a::IntExpr, b::IntExpr) Base.:>=(a::IntExpr, b::IntExpr) ``` +### Conversion operators +```@docs +to_int(a::RealExpr) +to_real(a::IntExpr) +``` + ## BitVector ```julia @satvariable(a, BitVector, 16) diff --git a/src/BooleanOperations.jl b/src/BooleanOperations.jl index dbcf89d..91a413e 100644 --- a/src/BooleanOperations.jl +++ b/src/BooleanOperations.jl @@ -223,22 +223,21 @@ end """ - ite(x::BoolExpr, y::BoolExpr, z::BoolExpr) + ite(x::BoolExpr, y::AbstractExpr, z::AbstractExpr) -If-then-else statement. Equivalent to `or(x ∧ y, ¬x ∧ z)`. +If-then-else statement. When x, y, and z are Bool, equivalent to `or(x ∧ y, ¬x ∧ z)`. Note that `y` and `z` may be other expression types. For example, given the variables `BoolExpr z` and `IntExpr a`, Satisfiability.jl rewrites `z + a` as `ite(z, 1, 0) + a`. """ -function ite(x::Union{BoolExpr, Bool}, y::Union{BoolExpr, Bool}, z::Union{BoolExpr, Bool}) +function ite(x::BoolExpr, y::T, z::T) where T <: AbstractExpr zs = [x, y, z] - if any(isa.(zs, Bool)) # if any of these is a literal - return or(and(x, y), and(not(x), z)) # this will simplify it correctly + if isa(x, Bool) # if x is literal + return x ? y : z end - value = any(isnothing.([x.value, y.value, z.value])) ? nothing : (x.value & y.value) | (!(x.value) & z.value) + value = any(isnothing.([x.value, y.value, z.value])) ? nothing : x ? y : z return BoolExpr(:ite, zs, value, __get_hash_name(:ite, zs)) end - ##### SUPPORT FOR OPERATIONS WITH MIXED LITERALS ##### not(z::Bool) = !z @@ -270,6 +269,11 @@ iff(z1::BoolExpr, z2::Bool) = z2 ? z1 : ¬z1 # if z2 is true z1 must be true and iff(z1::Bool, z2::BoolExpr) = z1 ? z2 : ¬z2 iff(z1::Bool, z2::Bool) = z1 == z2 +ite(x::Bool, y::Any, z::Any) = x ? y : z +ite(x::BoolExpr, y::Any, z::T) where T <: AbstractExpr = ite(x, __wrap_const(y), z) +ite(x::BoolExpr, y::T, z::Any) where T <: AbstractExpr = ite(x, y, __wrap_const(z)) +ite(x::BoolExpr, y::T, z::T) where T <: Any = ite(x, __wrap_const(y), __wrap_const(z)) + """ value(z::BoolExpr) value(z::Array{BoolExpr}) diff --git a/src/IntExpr.jl b/src/IntExpr.jl index 30a12d4..f7591bd 100644 --- a/src/IntExpr.jl +++ b/src/IntExpr.jl @@ -1,4 +1,4 @@ -import Base.<, Base.<=, Base.>, Base.<=, Base.+, Base.-, Base.*, Base./, Base.==, Base.!= +import Base.<, Base.<=, Base.>, Base.<=, Base.+, Base.-, Base.*, Base./, Base.div, Base.==, Base.!=, Base.promote_rule, Base.convert abstract type NumericExpr <: AbstractExpr end @@ -70,17 +70,17 @@ function RealExpr(name::String) :: RealExpr end -# These are necessary for defining interoperability between IntExpr, RealExpr, BoolExpr and built-in types such as Int, Bool, and Float. +# These are necessary for defining interoperability between IntExpr, RealExpr, and built-in types such as Int, Bool, and Float. NumericInteroperableExpr = Union{NumericExpr, BoolExpr} NumericInteroperableConst = Union{Bool, Int, Float64} -NumericInteroperable = Union{NumericInteroperableExpr, NumericInteroperableConst} __wrap_const(c::Float64) = RealExpr(:const, AbstractExpr[], c, c >= 0 ? "const_$c" : "const_neg_$(abs(c))") -__wrap_const(c::Union{Int, Bool}) = IntExpr(:const, AbstractExpr[], c, c >= 0 ? "const_$c" : "const_neg_$(abs(c))") # prevents names like -1 from being generated, which are disallowed in SMT-LIB +__wrap_const(c::Int) = IntExpr(:const, AbstractExpr[], c, c >= 0 ? "const_$c" : "const_neg_$(abs(c))") # prevents names like -1 from being generated, which are disallowed in SMT-LIB +__wrap_const(c::Bool) = BoolExpr(:const, AbstractExpr[], c, "const_$c") ##### COMPARISON OPERATIONS #### -# These return Boolean values. In the SMT dialect we would say they have sort Bool +# These return Boolean values, eg they have sort Bool # See figure 3.3 in the SMT-LIB standard. """ a < b @@ -96,7 +96,7 @@ a .< b a .< z ``` """ -function Base.:<(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) +function Base.:<(e1::T, e2::T) where T <: NumericInteroperableExpr value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value < e2.value name = __get_hash_name(:lt, [e1, e2]) return BoolExpr(:lt, [e1, e2], value, name) @@ -116,7 +116,7 @@ a .<= b a .<= z ``` """ -function Base.:<=(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) +function Base.:<=(e1::T, e2::T) where T <: NumericInteroperableExpr value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value <= e2.value name = __get_hash_name(:leq, [e1, e2]) return BoolExpr(:leq, [e1, e2], value, name) @@ -136,7 +136,7 @@ a .>= b a .>= z ``` """ -function Base.:>=(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) +function Base.:>=(e1::T, e2::T) where T <: NumericInteroperableExpr value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value >= e2.value name = __get_hash_name(:geq, [e1, e2]) return BoolExpr(:geq, [e1, e2], value, name) @@ -156,7 +156,7 @@ a .> b a .> z ``` """ -function Base.:>(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) +function Base.:>(e1::T, e2::T) where T <: NumericInteroperableExpr value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value > e2.value name = __get_hash_name(:gt, [e1, e2]) return BoolExpr(:gt, [e1, e2], value, name) @@ -181,7 +181,7 @@ a .== b **Note:** To test whether two `AbstractExpr`s are eqivalent (in the sense that all properties are equal, not in the shared-memory-location sense of `===`), use `isequal`. """ -function Base.:(==)(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) +function Base.:(==)(e1::T, e2::T) where T <: NumericInteroperableExpr value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value == e2.value name = __get_hash_name(:eq, [e1, e2], is_commutative=true) return BoolExpr(:eq, [e1, e2], value, name, __is_commutative=true) @@ -203,7 +203,7 @@ isequal( ) ```` """ -function distinct(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) +function distinct(e1::T, e2::T) where T <: NumericInteroperableExpr value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value != e2.value name = __get_hash_name(:distinct, [e1, e2], is_commutative=true) return BoolExpr(:distinct, [e1, e2], value, name, __is_commutative=true) @@ -220,19 +220,27 @@ distinct(es::Base.Generator) = distinct(collect(es)) # INTEROPERABILITY FOR COMPARISON OPERATIONS +Base.:>(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = Base.:>(promote(e1, e2)...) Base.:>(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = e1 > __wrap_const(e2) Base.:>(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __wrap_const(e1) > e2 + +Base.:>=(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = Base.:>=(promote(e1, e2)...) Base.:>=(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = e1 >= __wrap_const(e2) Base.:>=(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __wrap_const(e1) >= e2 +Base.:<(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = Base.:<(promote(e1, e2)...) Base.:<(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = e1 < __wrap_const(e2) Base.:<(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __wrap_const(e1) < e2 + +Base.:<=(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = Base.:<=(promote(e1, e2)...) Base.:<=(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = e1 <= __wrap_const(e2) Base.:<=(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __wrap_const(e1) <= e2 +Base.:(==)(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = Base.:(==)(promote(e1, e2)...) Base.:(==)(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = e1 == __wrap_const(e2) Base.:(==)(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __wrap_const(e1) == e2 +distinct(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = distinct(promote(e1, e2)...) distinct(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = distinct(e1, __wrap_const(e2)) distinct(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = distinct(__wrap_const(e1), e2) distinct(e1::NumericInteroperableConst, e2::NumericInteroperableConst) = e1 != e2 @@ -279,7 +287,7 @@ function __merge_const!(es::Array{T}) where T <: AbstractExpr end end -# This works for any n_ary op that takes as input NumericInteroperable arguments +# This works for any n_ary op that takes as input NumericInteroperableExpr arguments function __numeric_n_ary_op(es_mixed::Array, op::Symbol; __is_commutative=false, __try_flatten=false) # clean up types! This guarantees es::Array{AbstractExpr} es, literals = __check_inputs_nary_op(es_mixed, const_type=NumericInteroperableConst, expr_type=NumericInteroperableExpr) @@ -291,7 +299,7 @@ function __numeric_n_ary_op(es_mixed::Array, op::Symbol; __is_commutative=false, end # Determine return expr type. Note that / promotes to RealExpr because the SMT theory of integers doesn't include it - ReturnType = any(isa.(es, RealExpr)) || op == :div ? RealExpr : IntExpr + ReturnType = any(isa.(es, RealExpr)) ? RealExpr : IntExpr children, name = __combine(es, op, __is_commutative, __try_flatten) # Now it is possible we have several CONST exprs. This occurs if, for example, one writes (a+1) + (b+1) which flattens to a+1+b+1 @@ -333,9 +341,9 @@ println("typeof a+z: \$(typeof(a[1] + z))") ``` """ -Base.:+(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :add, __is_commutative=true, __try_flatten=true) -Base.:+(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :add, __is_commutative=true, __try_flatten=true) -Base.:+(e1::Union{NumericInteroperableConst}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :add, __is_commutative=true, __try_flatten=true) +Base.:+(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = __numeric_n_ary_op(collect(promote(e1, e2)), :add, __is_commutative=true, __try_flatten=true) +Base.:+(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :add, __is_commutative=true, __try_flatten=true) +Base.:+(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :add, __is_commutative=true, __try_flatten=true) """ a - b @@ -357,9 +365,9 @@ a .- z println("typeof a-z: \$(typeof(a[1] - z))") ``` """ -Base.:-(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :sub) -Base.:-(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :sub) -Base.:-(e1::Union{NumericInteroperableConst}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :sub) +Base.:-(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = __numeric_n_ary_op(collect(promote(e1, e2)), :sub) +Base.:-(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :sub) +Base.:-(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :sub) """ a * b @@ -381,15 +389,34 @@ a .- z println("typeof a*z: \$(typeof(a[1]*z))") ``` """ -Base.:*(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :mul, __is_commutative=true, __try_flatten=true) -Base.:*(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :mul, __is_commutative=true, __try_flatten=true) -Base.:*(e1::Union{NumericInteroperableConst}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :mul, __is_commutative=true, __try_flatten=true) +Base.:*(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = __numeric_n_ary_op(collect(promote(e1, e2)), :mul, __is_commutative=true, __try_flatten=true) +Base.:*(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :mul, __is_commutative=true, __try_flatten=true) +Base.:*(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :mul, __is_commutative=true, __try_flatten=true) + + +""" + div(a, b) + div(a, 2) + +Returns the `Int` division expression `div(a,b)`. Note: `a` and `b` will be converted to `IntExpr`). Use dot broadcasting for vector-valued and matrix-valued Boolean expressions. + +```julia +@satvariable(a[1:n], Int) +@satvariable(b[1:n, 1:m], Int) +div.(a, b) +println("typeof div(a,b): \$(typeof(div(a[1],b[1])))") +``` +""" +Base.div(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = __numeric_n_ary_op([convert(IntExpr, e1), convert(IntExpr, e2)], :div) +Base.div(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :div) +Base.div(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :div) + """ a / b - a / 1.0 + a / 2.0 -Returns the `Real` division expression `a/b`. Note: `a` and `b` must be `Real`). Use dot broadcasting for vector-valued and matrix-valued Boolean expressions. +Returns the `Real` division expression `a/b`. Note: `a` and `b` will be converted to `RealExpr`). Use dot broadcasting for vector-valued and matrix-valued Boolean expressions. ```julia @satvariable(a[1:n], Real) @@ -398,6 +425,32 @@ a ./ b println("typeof a/b: \$(typeof(a[1]/b[1]))") ``` """ -Base.:/(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :div) -Base.:/(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :div) -Base.:/(e1::Union{NumericInteroperableConst}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :div) \ No newline at end of file +Base.:/(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = __numeric_n_ary_op([convert(RealExpr, e1), convert(RealExpr, e2)], :rdiv) +Base.:/(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :rdiv) +Base.:/(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :rdiv) + + +""" + to_real(a::IntExpr) + +Performs manual conversion of an IntExpr to a RealExpr. Note that Satisfiability.jl automatically promotes types in arithmetic and comparison expressions, so this function is usually unnecessary to explicitly call. +""" +to_real(a::IntExpr) = RealExpr(:to_real, [a], isnothing(a.value) ? nothing : Float64(a.value), __get_hash_name(:to_real, [a])) + +""" + to_int(a::RealExpr) + +Performs manual conversion of a RealExpr to an IntExpr. +""" +to_int(a::RealExpr) = IntExpr(:to_int, [a], isnothing(a.value) ? nothing : Int(floor(a.value)), __get_hash_name(:to_int, [a])) + +##### PROMOTION RULES ##### +# These govern the promotion of BoolExpr, IntExpr and RealExpr types. +Base.promote_rule(::Type{IntExpr}, ::Type{BoolExpr}) = IntExpr +Base.promote_rule(::Type{RealExpr}, ::Type{BoolExpr}) = RealExpr +Base.promote_rule(::Type{RealExpr}, ::Type{IntExpr}) = RealExpr + + +Base.convert(::Type{IntExpr}, z::BoolExpr) = z.op == :const ? __wrap_const(Int64(z.value)) : ite(z, 1, 0) +Base.convert(::Type{RealExpr}, z::BoolExpr) = z.op == :const ? __wrap_const(Float64(z.value)) : ite(z, 1.0, 0.0) +Base.convert(::Type{RealExpr}, a::IntExpr) = a.op == :const ? __wrap_const(Float64(a.value)) : RealExpr(:to_real, [a], isnothing(a.value) ? nothing : Float64(a.value), __get_hash_name(:to_real, [a])) diff --git a/src/Satisfiability.jl b/src/Satisfiability.jl index 5b5f3eb..1ff19fc 100644 --- a/src/Satisfiability.jl +++ b/src/Satisfiability.jl @@ -11,7 +11,9 @@ export AbstractExpr, BitVectorExpr, isequal, hash, # required by isequal (?) - in # specialize to use isequal instead of == + in, # specialize to use isequal instead of == + promote_rule, + convert export and, ∧, @@ -27,7 +29,9 @@ export ==, <, <=, >, >=, distinct export - +, -, *, / + +, -, *, /, + to_real, + to_int # BitVector specific functions export diff --git a/src/sat.jl b/src/sat.jl index 1574370..c26d55f 100644 --- a/src/sat.jl +++ b/src/sat.jl @@ -269,7 +269,7 @@ __julia_symbolic_ops = Dict( :add => +, :sub => -, :mul => *, - :div => /, + :rdiv => /, :neg => -, :lt => <, :leq => <=, diff --git a/src/smt_representation.jl b/src/smt_representation.jl index d14990a..1a1ea4d 100644 --- a/src/smt_representation.jl +++ b/src/smt_representation.jl @@ -26,7 +26,7 @@ __smt_symbolic_ops = Dict( :add => "+", :sub => "-", :mul => "*", - :div => "/", + :rdiv => "/", :neg => "-", :lt => "<", :leq => "<=", @@ -44,7 +44,7 @@ __smt_generated_ops = Dict( # Finally, we provide facilities for correct encoding of consts function __format_smt_const(exprtype::Type, c::AbstractExpr) # there's no such thing as a Bool const because all Bool consts are simplifiable - if exprtype <: IntExpr || exprtype <: RealExpr + if exprtype <: IntExpr || exprtype <: RealExpr || exprtype <: BoolExpr return string(c.value) # automatically does the right thing for Ints and Reals elseif exprtype <: AbstractBitVectorExpr if c.length % 4 == 0 # can be a hex string diff --git a/src/utilities.jl b/src/utilities.jl index f74dd38..193d31c 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -39,10 +39,10 @@ function __combine(zs::Array{T}, op::Symbol, __is_commutative=false, __try_flatt # if this is an op where it makes sense to flatten (eg, and(and(x,y), and(y,z)) then flatten it) ops = getproperty.(zs, :op) if __try_flatten && (all(ops .== op) || - (__is_commutative && all(map( (o) -> o in [:identity, :const, op], ops)))) + (__is_commutative && all(map( (o) -> o in [:identity, :const, :to_real, :to_int, op], ops)))) # Returm a combined operator # this line merges childless operators and children, eg and(x, and(y,z)) yields [x, y, z] - children = cat(map( (e) -> length(e.children) > 0 ? e.children : [e], zs)..., dims=1) + children = cat(map( (e) -> length(e.children) == 0 || e.op ∈ [:to_real, :to_int] ? [e] : e.children, zs)..., dims=1) else # op doesn't match, so we won't flatten it children = zs end diff --git a/test/boolean_operation_tests.jl b/test/boolean_operation_tests.jl index 60e5cf2..6000b46 100644 --- a/test/boolean_operation_tests.jl +++ b/test/boolean_operation_tests.jl @@ -184,6 +184,6 @@ end @test all(isequal.(iff.(z, A), iff.(A, z))) y = @satvariable(y[1:1], Bool) - @test all( isequal.(ite.(z, true, false), or.(and.(z, true), and.(¬z, false)) )) - @test all( isequal.(ite.(false, y, z), or.(and.(false, y), and.(true, z)) )) + @test all(isequal.(ite.(true, z, y), z )) + @test all(isequal.(ite.(false, true, y), y )) end \ No newline at end of file diff --git a/test/int_real_tests.jl b/test/int_real_tests.jl index 0ee7633..f3b1df9 100644 --- a/test/int_real_tests.jl +++ b/test/int_real_tests.jl @@ -64,7 +64,7 @@ end @test isequal(sum([1.0, a, true, 1]), RealExpr(:add, children, nothing, Satisfiability.__get_hash_name(:add, children, is_commutative=true))) # Type promotion to RealExpr works when we add a real-valued expr - children = [a, b[1], IntExpr(:const, AbstractExpr[], 2, "const_2.0")] + children = [to_real(a), to_real(b[1]), RealExpr(:const, AbstractExpr[], 2.0, "const_2.0")] @test isequal(sum([a, 1.0, 1, false, b[1]]), RealExpr(:add, children, nothing, Satisfiability.__get_hash_name(:add, children, is_commutative=true))) # Sum works automatically diff --git a/test/output_parse_tests.jl b/test/output_parse_tests.jl index e8979f9..270cda6 100644 --- a/test/output_parse_tests.jl +++ b/test/output_parse_tests.jl @@ -86,8 +86,8 @@ end b = a @satvariable(a, Real) hashname = Satisfiability.__get_hash_name(:add, [b, a], is_commutative=true) - @test smt(b+a, assert=false) == "(declare-fun a () Int) -(declare-fun a () Real) -(define-fun $hashname () Real (+ (as a Int) (as a Real))) + @test smt(b+a, assert=false) == "(declare-fun a () Real) +(declare-fun a () Int) +(define-fun $hashname () Real (+ (as a Real) (to_real (as a Int)))) " end \ No newline at end of file diff --git a/test/solver_interface_tests.jl b/test/solver_interface_tests.jl index a7a580c..9ec4349 100644 --- a/test/solver_interface_tests.jl +++ b/test/solver_interface_tests.jl @@ -1,4 +1,4 @@ -push!(LOAD_PATH, "../src") +push!(LOAD_PATH, "./src") using Satisfiability using Test, Logging @@ -80,7 +80,7 @@ using Test, Logging values = Dict("ar2_1"=>1., "ar2_2"=>2.) @satvariable(ar2[1:2], Real) - test_expr = RealExpr(:div, ar2, nothing, "test") + test_expr = RealExpr(:rdiv, ar2, nothing, "test") assign!(test_expr, values) @test value(test_expr) == (1. / 2.) From 54266436bec4ba428c4d9971aff23177401e69bd Mon Sep 17 00:00:00 2001 From: Emiko Soroka Date: Tue, 28 Nov 2023 21:15:28 -0800 Subject: [PATCH 2/7] All tests pass. (1) Expanded ite(x,y,z) to allow non-Boolean values for y and z and simplified code. (2) Added promote_type and convert rules to codify promotion rules for IntExpr, BoolExpr and Real Expr. (3) Added integer division and cleaned up type conversion around integer and real-valued division. (4) Added to_real and to_int SMT-LIB functions. (5) Fixed issue#21 using (2) and (4) to ensure correct promotion by wrapping promoted BoolExprs using ite and wrapping promoted IntExprs using to_real. --- docs/src/functions.md | 9 ++- src/BooleanOperations.jl | 18 +++--- src/IntExpr.jl | 107 ++++++++++++++++++++++++-------- src/Satisfiability.jl | 8 ++- src/sat.jl | 2 +- src/smt_representation.jl | 4 +- src/utilities.jl | 4 +- test/boolean_operation_tests.jl | 4 +- test/int_real_tests.jl | 2 +- test/output_parse_tests.jl | 6 +- test/solver_interface_tests.jl | 8 ++- 11 files changed, 122 insertions(+), 50 deletions(-) diff --git a/docs/src/functions.md b/docs/src/functions.md index 23e7542..31a6774 100644 --- a/docs/src/functions.md +++ b/docs/src/functions.md @@ -32,7 +32,7 @@ distinct(z1::BoolExpr, z2::BoolExpr) ## Arithmetic operations These are operations in the theory of integer and real-valued arithmetic. -Note that `+`, `-`, and `*` follow type promotion rules: if both `a` and `b` are `IntExpr`s, `a+b` will have type `IntExpr`. If either `a` or `b` is a `RealExpr`, the result will have type `RealExpr`. Division `\` is defined only in the theory of real-valued arithmetic, thus it always has return type `RealExpr`. +Note that `+`, `-`, and `*` follow type promotion rules: if both `a` and `b` are `IntExpr`s, `a+b` will have type `IntExpr`. If either `a` or `b` is a `RealExpr`, the result will have type `RealExpr`. Integer division `div(a,b)` is defined only for `IntExpr`s. Real-valued division `a\b` is defined only in the theory of real-valued arithmetic. For a formal definition of the theory of integer arithmetic, see Figure 3.3 in *The SMT-LIB Standard, Version 2.6*. ```@docs @@ -40,6 +40,7 @@ Base.:-(a::IntExpr) Base.:+(a::IntExpr, b::IntExpr) Base.:-(a::IntExpr, b::IntExpr) Base.:*(a::RealExpr, b::RealExpr) +Base.div(a::IntExpr, b::IntExpr) Base.:/(a::RealExpr, b::RealExpr) ``` @@ -57,6 +58,12 @@ Base.:>(a::IntExpr, b::IntExpr) Base.:>=(a::IntExpr, b::IntExpr) ``` +### Conversion operators +```@docs +to_int(a::RealExpr) +to_real(a::IntExpr) +``` + ## BitVector ```julia @satvariable(a, BitVector, 16) diff --git a/src/BooleanOperations.jl b/src/BooleanOperations.jl index dbcf89d..91a413e 100644 --- a/src/BooleanOperations.jl +++ b/src/BooleanOperations.jl @@ -223,22 +223,21 @@ end """ - ite(x::BoolExpr, y::BoolExpr, z::BoolExpr) + ite(x::BoolExpr, y::AbstractExpr, z::AbstractExpr) -If-then-else statement. Equivalent to `or(x ∧ y, ¬x ∧ z)`. +If-then-else statement. When x, y, and z are Bool, equivalent to `or(x ∧ y, ¬x ∧ z)`. Note that `y` and `z` may be other expression types. For example, given the variables `BoolExpr z` and `IntExpr a`, Satisfiability.jl rewrites `z + a` as `ite(z, 1, 0) + a`. """ -function ite(x::Union{BoolExpr, Bool}, y::Union{BoolExpr, Bool}, z::Union{BoolExpr, Bool}) +function ite(x::BoolExpr, y::T, z::T) where T <: AbstractExpr zs = [x, y, z] - if any(isa.(zs, Bool)) # if any of these is a literal - return or(and(x, y), and(not(x), z)) # this will simplify it correctly + if isa(x, Bool) # if x is literal + return x ? y : z end - value = any(isnothing.([x.value, y.value, z.value])) ? nothing : (x.value & y.value) | (!(x.value) & z.value) + value = any(isnothing.([x.value, y.value, z.value])) ? nothing : x ? y : z return BoolExpr(:ite, zs, value, __get_hash_name(:ite, zs)) end - ##### SUPPORT FOR OPERATIONS WITH MIXED LITERALS ##### not(z::Bool) = !z @@ -270,6 +269,11 @@ iff(z1::BoolExpr, z2::Bool) = z2 ? z1 : ¬z1 # if z2 is true z1 must be true and iff(z1::Bool, z2::BoolExpr) = z1 ? z2 : ¬z2 iff(z1::Bool, z2::Bool) = z1 == z2 +ite(x::Bool, y::Any, z::Any) = x ? y : z +ite(x::BoolExpr, y::Any, z::T) where T <: AbstractExpr = ite(x, __wrap_const(y), z) +ite(x::BoolExpr, y::T, z::Any) where T <: AbstractExpr = ite(x, y, __wrap_const(z)) +ite(x::BoolExpr, y::T, z::T) where T <: Any = ite(x, __wrap_const(y), __wrap_const(z)) + """ value(z::BoolExpr) value(z::Array{BoolExpr}) diff --git a/src/IntExpr.jl b/src/IntExpr.jl index 30a12d4..f7591bd 100644 --- a/src/IntExpr.jl +++ b/src/IntExpr.jl @@ -1,4 +1,4 @@ -import Base.<, Base.<=, Base.>, Base.<=, Base.+, Base.-, Base.*, Base./, Base.==, Base.!= +import Base.<, Base.<=, Base.>, Base.<=, Base.+, Base.-, Base.*, Base./, Base.div, Base.==, Base.!=, Base.promote_rule, Base.convert abstract type NumericExpr <: AbstractExpr end @@ -70,17 +70,17 @@ function RealExpr(name::String) :: RealExpr end -# These are necessary for defining interoperability between IntExpr, RealExpr, BoolExpr and built-in types such as Int, Bool, and Float. +# These are necessary for defining interoperability between IntExpr, RealExpr, and built-in types such as Int, Bool, and Float. NumericInteroperableExpr = Union{NumericExpr, BoolExpr} NumericInteroperableConst = Union{Bool, Int, Float64} -NumericInteroperable = Union{NumericInteroperableExpr, NumericInteroperableConst} __wrap_const(c::Float64) = RealExpr(:const, AbstractExpr[], c, c >= 0 ? "const_$c" : "const_neg_$(abs(c))") -__wrap_const(c::Union{Int, Bool}) = IntExpr(:const, AbstractExpr[], c, c >= 0 ? "const_$c" : "const_neg_$(abs(c))") # prevents names like -1 from being generated, which are disallowed in SMT-LIB +__wrap_const(c::Int) = IntExpr(:const, AbstractExpr[], c, c >= 0 ? "const_$c" : "const_neg_$(abs(c))") # prevents names like -1 from being generated, which are disallowed in SMT-LIB +__wrap_const(c::Bool) = BoolExpr(:const, AbstractExpr[], c, "const_$c") ##### COMPARISON OPERATIONS #### -# These return Boolean values. In the SMT dialect we would say they have sort Bool +# These return Boolean values, eg they have sort Bool # See figure 3.3 in the SMT-LIB standard. """ a < b @@ -96,7 +96,7 @@ a .< b a .< z ``` """ -function Base.:<(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) +function Base.:<(e1::T, e2::T) where T <: NumericInteroperableExpr value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value < e2.value name = __get_hash_name(:lt, [e1, e2]) return BoolExpr(:lt, [e1, e2], value, name) @@ -116,7 +116,7 @@ a .<= b a .<= z ``` """ -function Base.:<=(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) +function Base.:<=(e1::T, e2::T) where T <: NumericInteroperableExpr value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value <= e2.value name = __get_hash_name(:leq, [e1, e2]) return BoolExpr(:leq, [e1, e2], value, name) @@ -136,7 +136,7 @@ a .>= b a .>= z ``` """ -function Base.:>=(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) +function Base.:>=(e1::T, e2::T) where T <: NumericInteroperableExpr value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value >= e2.value name = __get_hash_name(:geq, [e1, e2]) return BoolExpr(:geq, [e1, e2], value, name) @@ -156,7 +156,7 @@ a .> b a .> z ``` """ -function Base.:>(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) +function Base.:>(e1::T, e2::T) where T <: NumericInteroperableExpr value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value > e2.value name = __get_hash_name(:gt, [e1, e2]) return BoolExpr(:gt, [e1, e2], value, name) @@ -181,7 +181,7 @@ a .== b **Note:** To test whether two `AbstractExpr`s are eqivalent (in the sense that all properties are equal, not in the shared-memory-location sense of `===`), use `isequal`. """ -function Base.:(==)(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) +function Base.:(==)(e1::T, e2::T) where T <: NumericInteroperableExpr value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value == e2.value name = __get_hash_name(:eq, [e1, e2], is_commutative=true) return BoolExpr(:eq, [e1, e2], value, name, __is_commutative=true) @@ -203,7 +203,7 @@ isequal( ) ```` """ -function distinct(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) +function distinct(e1::T, e2::T) where T <: NumericInteroperableExpr value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value != e2.value name = __get_hash_name(:distinct, [e1, e2], is_commutative=true) return BoolExpr(:distinct, [e1, e2], value, name, __is_commutative=true) @@ -220,19 +220,27 @@ distinct(es::Base.Generator) = distinct(collect(es)) # INTEROPERABILITY FOR COMPARISON OPERATIONS +Base.:>(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = Base.:>(promote(e1, e2)...) Base.:>(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = e1 > __wrap_const(e2) Base.:>(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __wrap_const(e1) > e2 + +Base.:>=(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = Base.:>=(promote(e1, e2)...) Base.:>=(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = e1 >= __wrap_const(e2) Base.:>=(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __wrap_const(e1) >= e2 +Base.:<(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = Base.:<(promote(e1, e2)...) Base.:<(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = e1 < __wrap_const(e2) Base.:<(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __wrap_const(e1) < e2 + +Base.:<=(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = Base.:<=(promote(e1, e2)...) Base.:<=(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = e1 <= __wrap_const(e2) Base.:<=(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __wrap_const(e1) <= e2 +Base.:(==)(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = Base.:(==)(promote(e1, e2)...) Base.:(==)(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = e1 == __wrap_const(e2) Base.:(==)(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __wrap_const(e1) == e2 +distinct(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = distinct(promote(e1, e2)...) distinct(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = distinct(e1, __wrap_const(e2)) distinct(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = distinct(__wrap_const(e1), e2) distinct(e1::NumericInteroperableConst, e2::NumericInteroperableConst) = e1 != e2 @@ -279,7 +287,7 @@ function __merge_const!(es::Array{T}) where T <: AbstractExpr end end -# This works for any n_ary op that takes as input NumericInteroperable arguments +# This works for any n_ary op that takes as input NumericInteroperableExpr arguments function __numeric_n_ary_op(es_mixed::Array, op::Symbol; __is_commutative=false, __try_flatten=false) # clean up types! This guarantees es::Array{AbstractExpr} es, literals = __check_inputs_nary_op(es_mixed, const_type=NumericInteroperableConst, expr_type=NumericInteroperableExpr) @@ -291,7 +299,7 @@ function __numeric_n_ary_op(es_mixed::Array, op::Symbol; __is_commutative=false, end # Determine return expr type. Note that / promotes to RealExpr because the SMT theory of integers doesn't include it - ReturnType = any(isa.(es, RealExpr)) || op == :div ? RealExpr : IntExpr + ReturnType = any(isa.(es, RealExpr)) ? RealExpr : IntExpr children, name = __combine(es, op, __is_commutative, __try_flatten) # Now it is possible we have several CONST exprs. This occurs if, for example, one writes (a+1) + (b+1) which flattens to a+1+b+1 @@ -333,9 +341,9 @@ println("typeof a+z: \$(typeof(a[1] + z))") ``` """ -Base.:+(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :add, __is_commutative=true, __try_flatten=true) -Base.:+(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :add, __is_commutative=true, __try_flatten=true) -Base.:+(e1::Union{NumericInteroperableConst}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :add, __is_commutative=true, __try_flatten=true) +Base.:+(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = __numeric_n_ary_op(collect(promote(e1, e2)), :add, __is_commutative=true, __try_flatten=true) +Base.:+(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :add, __is_commutative=true, __try_flatten=true) +Base.:+(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :add, __is_commutative=true, __try_flatten=true) """ a - b @@ -357,9 +365,9 @@ a .- z println("typeof a-z: \$(typeof(a[1] - z))") ``` """ -Base.:-(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :sub) -Base.:-(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :sub) -Base.:-(e1::Union{NumericInteroperableConst}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :sub) +Base.:-(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = __numeric_n_ary_op(collect(promote(e1, e2)), :sub) +Base.:-(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :sub) +Base.:-(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :sub) """ a * b @@ -381,15 +389,34 @@ a .- z println("typeof a*z: \$(typeof(a[1]*z))") ``` """ -Base.:*(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :mul, __is_commutative=true, __try_flatten=true) -Base.:*(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :mul, __is_commutative=true, __try_flatten=true) -Base.:*(e1::Union{NumericInteroperableConst}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :mul, __is_commutative=true, __try_flatten=true) +Base.:*(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = __numeric_n_ary_op(collect(promote(e1, e2)), :mul, __is_commutative=true, __try_flatten=true) +Base.:*(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :mul, __is_commutative=true, __try_flatten=true) +Base.:*(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :mul, __is_commutative=true, __try_flatten=true) + + +""" + div(a, b) + div(a, 2) + +Returns the `Int` division expression `div(a,b)`. Note: `a` and `b` will be converted to `IntExpr`). Use dot broadcasting for vector-valued and matrix-valued Boolean expressions. + +```julia +@satvariable(a[1:n], Int) +@satvariable(b[1:n, 1:m], Int) +div.(a, b) +println("typeof div(a,b): \$(typeof(div(a[1],b[1])))") +``` +""" +Base.div(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = __numeric_n_ary_op([convert(IntExpr, e1), convert(IntExpr, e2)], :div) +Base.div(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :div) +Base.div(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :div) + """ a / b - a / 1.0 + a / 2.0 -Returns the `Real` division expression `a/b`. Note: `a` and `b` must be `Real`). Use dot broadcasting for vector-valued and matrix-valued Boolean expressions. +Returns the `Real` division expression `a/b`. Note: `a` and `b` will be converted to `RealExpr`). Use dot broadcasting for vector-valued and matrix-valued Boolean expressions. ```julia @satvariable(a[1:n], Real) @@ -398,6 +425,32 @@ a ./ b println("typeof a/b: \$(typeof(a[1]/b[1]))") ``` """ -Base.:/(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :div) -Base.:/(e1::Union{NumericInteroperableExpr}, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :div) -Base.:/(e1::Union{NumericInteroperableConst}, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :div) \ No newline at end of file +Base.:/(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = __numeric_n_ary_op([convert(RealExpr, e1), convert(RealExpr, e2)], :rdiv) +Base.:/(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :rdiv) +Base.:/(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :rdiv) + + +""" + to_real(a::IntExpr) + +Performs manual conversion of an IntExpr to a RealExpr. Note that Satisfiability.jl automatically promotes types in arithmetic and comparison expressions, so this function is usually unnecessary to explicitly call. +""" +to_real(a::IntExpr) = RealExpr(:to_real, [a], isnothing(a.value) ? nothing : Float64(a.value), __get_hash_name(:to_real, [a])) + +""" + to_int(a::RealExpr) + +Performs manual conversion of a RealExpr to an IntExpr. +""" +to_int(a::RealExpr) = IntExpr(:to_int, [a], isnothing(a.value) ? nothing : Int(floor(a.value)), __get_hash_name(:to_int, [a])) + +##### PROMOTION RULES ##### +# These govern the promotion of BoolExpr, IntExpr and RealExpr types. +Base.promote_rule(::Type{IntExpr}, ::Type{BoolExpr}) = IntExpr +Base.promote_rule(::Type{RealExpr}, ::Type{BoolExpr}) = RealExpr +Base.promote_rule(::Type{RealExpr}, ::Type{IntExpr}) = RealExpr + + +Base.convert(::Type{IntExpr}, z::BoolExpr) = z.op == :const ? __wrap_const(Int64(z.value)) : ite(z, 1, 0) +Base.convert(::Type{RealExpr}, z::BoolExpr) = z.op == :const ? __wrap_const(Float64(z.value)) : ite(z, 1.0, 0.0) +Base.convert(::Type{RealExpr}, a::IntExpr) = a.op == :const ? __wrap_const(Float64(a.value)) : RealExpr(:to_real, [a], isnothing(a.value) ? nothing : Float64(a.value), __get_hash_name(:to_real, [a])) diff --git a/src/Satisfiability.jl b/src/Satisfiability.jl index 5b5f3eb..1ff19fc 100644 --- a/src/Satisfiability.jl +++ b/src/Satisfiability.jl @@ -11,7 +11,9 @@ export AbstractExpr, BitVectorExpr, isequal, hash, # required by isequal (?) - in # specialize to use isequal instead of == + in, # specialize to use isequal instead of == + promote_rule, + convert export and, ∧, @@ -27,7 +29,9 @@ export ==, <, <=, >, >=, distinct export - +, -, *, / + +, -, *, /, + to_real, + to_int # BitVector specific functions export diff --git a/src/sat.jl b/src/sat.jl index 1574370..c26d55f 100644 --- a/src/sat.jl +++ b/src/sat.jl @@ -269,7 +269,7 @@ __julia_symbolic_ops = Dict( :add => +, :sub => -, :mul => *, - :div => /, + :rdiv => /, :neg => -, :lt => <, :leq => <=, diff --git a/src/smt_representation.jl b/src/smt_representation.jl index d14990a..1a1ea4d 100644 --- a/src/smt_representation.jl +++ b/src/smt_representation.jl @@ -26,7 +26,7 @@ __smt_symbolic_ops = Dict( :add => "+", :sub => "-", :mul => "*", - :div => "/", + :rdiv => "/", :neg => "-", :lt => "<", :leq => "<=", @@ -44,7 +44,7 @@ __smt_generated_ops = Dict( # Finally, we provide facilities for correct encoding of consts function __format_smt_const(exprtype::Type, c::AbstractExpr) # there's no such thing as a Bool const because all Bool consts are simplifiable - if exprtype <: IntExpr || exprtype <: RealExpr + if exprtype <: IntExpr || exprtype <: RealExpr || exprtype <: BoolExpr return string(c.value) # automatically does the right thing for Ints and Reals elseif exprtype <: AbstractBitVectorExpr if c.length % 4 == 0 # can be a hex string diff --git a/src/utilities.jl b/src/utilities.jl index f74dd38..193d31c 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -39,10 +39,10 @@ function __combine(zs::Array{T}, op::Symbol, __is_commutative=false, __try_flatt # if this is an op where it makes sense to flatten (eg, and(and(x,y), and(y,z)) then flatten it) ops = getproperty.(zs, :op) if __try_flatten && (all(ops .== op) || - (__is_commutative && all(map( (o) -> o in [:identity, :const, op], ops)))) + (__is_commutative && all(map( (o) -> o in [:identity, :const, :to_real, :to_int, op], ops)))) # Returm a combined operator # this line merges childless operators and children, eg and(x, and(y,z)) yields [x, y, z] - children = cat(map( (e) -> length(e.children) > 0 ? e.children : [e], zs)..., dims=1) + children = cat(map( (e) -> length(e.children) == 0 || e.op ∈ [:to_real, :to_int] ? [e] : e.children, zs)..., dims=1) else # op doesn't match, so we won't flatten it children = zs end diff --git a/test/boolean_operation_tests.jl b/test/boolean_operation_tests.jl index 60e5cf2..6000b46 100644 --- a/test/boolean_operation_tests.jl +++ b/test/boolean_operation_tests.jl @@ -184,6 +184,6 @@ end @test all(isequal.(iff.(z, A), iff.(A, z))) y = @satvariable(y[1:1], Bool) - @test all( isequal.(ite.(z, true, false), or.(and.(z, true), and.(¬z, false)) )) - @test all( isequal.(ite.(false, y, z), or.(and.(false, y), and.(true, z)) )) + @test all(isequal.(ite.(true, z, y), z )) + @test all(isequal.(ite.(false, true, y), y )) end \ No newline at end of file diff --git a/test/int_real_tests.jl b/test/int_real_tests.jl index 0ee7633..f3b1df9 100644 --- a/test/int_real_tests.jl +++ b/test/int_real_tests.jl @@ -64,7 +64,7 @@ end @test isequal(sum([1.0, a, true, 1]), RealExpr(:add, children, nothing, Satisfiability.__get_hash_name(:add, children, is_commutative=true))) # Type promotion to RealExpr works when we add a real-valued expr - children = [a, b[1], IntExpr(:const, AbstractExpr[], 2, "const_2.0")] + children = [to_real(a), to_real(b[1]), RealExpr(:const, AbstractExpr[], 2.0, "const_2.0")] @test isequal(sum([a, 1.0, 1, false, b[1]]), RealExpr(:add, children, nothing, Satisfiability.__get_hash_name(:add, children, is_commutative=true))) # Sum works automatically diff --git a/test/output_parse_tests.jl b/test/output_parse_tests.jl index e8979f9..270cda6 100644 --- a/test/output_parse_tests.jl +++ b/test/output_parse_tests.jl @@ -86,8 +86,8 @@ end b = a @satvariable(a, Real) hashname = Satisfiability.__get_hash_name(:add, [b, a], is_commutative=true) - @test smt(b+a, assert=false) == "(declare-fun a () Int) -(declare-fun a () Real) -(define-fun $hashname () Real (+ (as a Int) (as a Real))) + @test smt(b+a, assert=false) == "(declare-fun a () Real) +(declare-fun a () Int) +(define-fun $hashname () Real (+ (as a Real) (to_real (as a Int)))) " end \ No newline at end of file diff --git a/test/solver_interface_tests.jl b/test/solver_interface_tests.jl index a7a580c..9afdf6e 100644 --- a/test/solver_interface_tests.jl +++ b/test/solver_interface_tests.jl @@ -1,4 +1,4 @@ -push!(LOAD_PATH, "../src") +push!(LOAD_PATH, "./src") using Satisfiability using Test, Logging @@ -78,9 +78,13 @@ using Test, Logging assign!(test_expr, values) @test value(test_expr) == -1 + test_expr.op = :div; test_expr.children = a3[2:3] + assign!(test_expr, values) + @test value(test_expr) == 0 + values = Dict("ar2_1"=>1., "ar2_2"=>2.) @satvariable(ar2[1:2], Real) - test_expr = RealExpr(:div, ar2, nothing, "test") + test_expr = RealExpr(:rdiv, ar2, nothing, "test") assign!(test_expr, values) @test value(test_expr) == (1. / 2.) From 6cd59ae3d44c87c8bddaedd6b22b00c3c4d9541f Mon Sep 17 00:00:00 2001 From: Emiko Soroka Date: Wed, 29 Nov 2023 11:28:22 -0800 Subject: [PATCH 3/7] finished work from yesterday, added unittest coverage. Added abs and mod operations for IntExprs. --- src/BooleanOperations.jl | 18 ++++++------- src/IntExpr.jl | 46 ++++++++++++++++++++++++---------- test/int_real_tests.jl | 10 ++++++++ test/solver_interface_tests.jl | 11 +++++++- 4 files changed, 62 insertions(+), 23 deletions(-) diff --git a/src/BooleanOperations.jl b/src/BooleanOperations.jl index 91a413e..2a60456 100644 --- a/src/BooleanOperations.jl +++ b/src/BooleanOperations.jl @@ -227,14 +227,17 @@ end If-then-else statement. When x, y, and z are Bool, equivalent to `or(x ∧ y, ¬x ∧ z)`. Note that `y` and `z` may be other expression types. For example, given the variables `BoolExpr z` and `IntExpr a`, Satisfiability.jl rewrites `z + a` as `ite(z, 1, 0) + a`. """ -function ite(x::BoolExpr, y::T, z::T) where T <: AbstractExpr - zs = [x, y, z] - if isa(x, Bool) # if x is literal - return x ? y : z +function ite(x::BoolExpr, y::Any, z::Any) + if !isa(y, AbstractExpr) + y = __wrap_const(y) end - + if !isa(z, AbstractExpr) + z = __wrap_const(z) + end + (y,z) = promote(y,z) + zs = AbstractExpr[x, y, z] value = any(isnothing.([x.value, y.value, z.value])) ? nothing : x ? y : z - return BoolExpr(:ite, zs, value, __get_hash_name(:ite, zs)) + return typeof(y)(:ite, zs, value, __get_hash_name(:ite, zs)) end @@ -270,9 +273,6 @@ iff(z1::Bool, z2::BoolExpr) = z1 ? z2 : ¬z2 iff(z1::Bool, z2::Bool) = z1 == z2 ite(x::Bool, y::Any, z::Any) = x ? y : z -ite(x::BoolExpr, y::Any, z::T) where T <: AbstractExpr = ite(x, __wrap_const(y), z) -ite(x::BoolExpr, y::T, z::Any) where T <: AbstractExpr = ite(x, y, __wrap_const(z)) -ite(x::BoolExpr, y::T, z::T) where T <: Any = ite(x, __wrap_const(y), __wrap_const(z)) """ value(z::BoolExpr) diff --git a/src/IntExpr.jl b/src/IntExpr.jl index f7591bd..f365bc9 100644 --- a/src/IntExpr.jl +++ b/src/IntExpr.jl @@ -1,4 +1,4 @@ -import Base.<, Base.<=, Base.>, Base.<=, Base.+, Base.-, Base.*, Base./, Base.div, Base.==, Base.!=, Base.promote_rule, Base.convert +import Base.<, Base.<=, Base.>, Base.<=, Base.+, Base.-, Base.*, Base./, Base.div, Base.mod, Base.abs, Base.==, Base.!=, Base.promote_rule, Base.convert abstract type NumericExpr <: AbstractExpr end @@ -265,6 +265,16 @@ Base.:-(e::RealExpr) = RealExpr(:neg, RealExpr[e,], isnothing(e.value) ? nothing # Define array version for convenience because the syntax .- for unary operators is confusing. Base.:-(es::Array{T}) where T <: NumericExpr = .-es +""" + abs(a::IntExpr) + +Return the absolute value of an `IntExpr`. + +When called on a `RealExpr`, `abs(a::RealExpr)` returns `ite(a >= 0, a, -a)`. This design decision was made because Z3 allows `abs` to be called on a real-valued expression and returns that result, but `abs` is only defined in the SMT-LIB standard for integer variables. Thus, users may call `abs` on real-valued expressions. +""" +Base.abs(e::IntExpr) = IntExpr(:abs, IntExpr[e,], isnothing(e.value) ? nothing : -e.value, __get_hash_name(:abs, [e,])) +Base.abs(e::RealExpr) = ite(e >= 0.0, e, -e) +Base.abs(e::BoolExpr) = ite(e, 1, 0) ##### COMBINING OPERATIONS ##### # These return Int values. We would say they have sort Int. @@ -323,7 +333,7 @@ end a + b a + 1 + true -Return the `Int` | `Real` expression `a+b` (inherits the type of `a+b`). Use dot broadcasting for vector-valued and matrix-valued Boolean expressions. +Return the `Int` | `Real` expression `a+b` (inherits the type of `a+b`). Use dot broadcasting for vector-valued and matrix-valued expressions. ```julia @@ -349,7 +359,7 @@ Base.:+(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric a - b a - 2 -Returns the `Int` | `Real` expression `a-b` (inherits the type of `a-b`). Use dot broadcasting for vector-valued and matrix-valued Boolean expressions. +Returns the `Int` | `Real` expression `a-b` (inherits the type of `a-b`). Use dot broadcasting for vector-valued and matrix-valued expressions. ```julia @satvariable(a[1:n], Int) @@ -373,7 +383,7 @@ Base.:-(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric a * b a * 2 -Returns the `Int` | `Real` multiplication expression `a*b` (inherits the type of `a*b`). Use dot broadcasting for vector-valued and matrix-valued Boolean expressions. +Returns the `Int` | `Real` multiplication expression `a*b` (inherits the type of `a*b`). Use dot broadcasting for vector-valued and matrix-valued expressions. ```julia @satvariable(a[1:n], Int) @@ -398,7 +408,7 @@ Base.:*(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric div(a, b) div(a, 2) -Returns the `Int` division expression `div(a,b)`. Note: `a` and `b` will be converted to `IntExpr`). Use dot broadcasting for vector-valued and matrix-valued Boolean expressions. +Returns the `Int` division expression `div(a,b)`. Note: `a` and `b` will be converted to `IntExpr`). Use dot broadcasting for vector-valued and matrix-valued expressions. ```julia @satvariable(a[1:n], Int) @@ -408,15 +418,24 @@ println("typeof div(a,b): \$(typeof(div(a[1],b[1])))") ``` """ Base.div(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = __numeric_n_ary_op([convert(IntExpr, e1), convert(IntExpr, e2)], :div) -Base.div(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :div) -Base.div(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :div) +Base.div(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([convert(IntExpr, e1), __wrap_const(Int(floor(e2)))], :div) +Base.div(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([__wrap_const(Int(floor(e1))), convert(IntExpr, e2)], :div) +""" + mod(a, b) + mod(a, 2) + +Returns the `Int` modulus expression `mod(a,b)`. Note: `a` and `b` will be converted to `IntExpr`). Use dot broadcasting for vector-valued and matrix-valued expressions. +""" +Base.mod(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = __numeric_n_ary_op([convert(IntExpr, e1), convert(IntExpr, e2)], :mod) +Base.mod(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([convert(IntExpr, e1), __wrap_const(Int(floor(e2)))], :mod) +Base.mod(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([__wrap_const(Int(floor(e1))), convert(IntExpr, e2)], :mod) """ a / b a / 2.0 -Returns the `Real` division expression `a/b`. Note: `a` and `b` will be converted to `RealExpr`). Use dot broadcasting for vector-valued and matrix-valued Boolean expressions. +Returns the `Real` division expression `a/b`. Note: `a` and `b` will be converted to `RealExpr`). Use dot broadcasting for vector-valued and matrix-valued expressions. ```julia @satvariable(a[1:n], Real) @@ -426,8 +445,8 @@ println("typeof a/b: \$(typeof(a[1]/b[1]))") ``` """ Base.:/(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = __numeric_n_ary_op([convert(RealExpr, e1), convert(RealExpr, e2)], :rdiv) -Base.:/(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([e1, e2], :rdiv) -Base.:/(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([e1, e2], :rdiv) +Base.:/(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([convert(RealExpr, e1), __wrap_const(Float64(e2))], :rdiv) +Base.:/(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric_n_ary_op([__wrap_const(Float64(e1)), convert(RealExpr, e2)], :rdiv) """ @@ -440,7 +459,7 @@ to_real(a::IntExpr) = RealExpr(:to_real, [a], isnothing(a.value) ? nothing : Flo """ to_int(a::RealExpr) -Performs manual conversion of a RealExpr to an IntExpr. +Performs manual conversion of a RealExpr to an IntExpr. Equivalent to Julia `Int(floor(a))`. """ to_int(a::RealExpr) = IntExpr(:to_int, [a], isnothing(a.value) ? nothing : Int(floor(a.value)), __get_hash_name(:to_int, [a])) @@ -451,6 +470,7 @@ Base.promote_rule(::Type{RealExpr}, ::Type{BoolExpr}) = RealExpr Base.promote_rule(::Type{RealExpr}, ::Type{IntExpr}) = RealExpr -Base.convert(::Type{IntExpr}, z::BoolExpr) = z.op == :const ? __wrap_const(Int64(z.value)) : ite(z, 1, 0) +Base.convert(::Type{IntExpr}, z::BoolExpr) = z.op == :const ? __wrap_const(Int64(z.value)) : ite(z, 1, 0) Base.convert(::Type{RealExpr}, z::BoolExpr) = z.op == :const ? __wrap_const(Float64(z.value)) : ite(z, 1.0, 0.0) -Base.convert(::Type{RealExpr}, a::IntExpr) = a.op == :const ? __wrap_const(Float64(a.value)) : RealExpr(:to_real, [a], isnothing(a.value) ? nothing : Float64(a.value), __get_hash_name(:to_real, [a])) +Base.convert(::Type{RealExpr}, a::IntExpr) = a.op == :const ? __wrap_const(Float64(a.value)) : to_real(a) +Base.convert(::Type{IntExpr}, a::RealExpr) = a.op == :const ? __wrap_const(Int64(a.value)) : to_int(a) diff --git a/test/int_real_tests.jl b/test/int_real_tests.jl index f3b1df9..ed9f699 100644 --- a/test/int_real_tests.jl +++ b/test/int_real_tests.jl @@ -72,4 +72,14 @@ end @test all(isequal.((a - 3).children, [a, IntExpr(:const, AbstractExpr[], 3, "const_3")])) @test all(isequal.((ar/3.0).children, [ar, RealExpr(:const, AbstractExpr[], 3., "const_3.0")])) + + # div, /, mod type coercion + @test isequal(div(2.0, ar), div(2, to_int(ar))) + @test isequal(a/2, to_real(a)/2.0) + @test isequal(mod(ar, 2.0), mod(to_int(ar), 2)) + + # abs rewrites to ite for non-int variables + @satvariable(z, Bool) + @test isequal(abs(z), ite(z, 1, 0)) + @test isequal(abs(ar), ite(ar >= 0.0, ar, -ar)) end \ No newline at end of file diff --git a/test/solver_interface_tests.jl b/test/solver_interface_tests.jl index 9afdf6e..455f82e 100644 --- a/test/solver_interface_tests.jl +++ b/test/solver_interface_tests.jl @@ -80,7 +80,16 @@ using Test, Logging test_expr.op = :div; test_expr.children = a3[2:3] assign!(test_expr, values) - @test value(test_expr) == 0 + @test value(test_expr) == div(2,3) + + test_expr.op = :mod; test_expr.children = a3[2:3] + assign!(test_expr, values) + @test value(test_expr) == mod(2,3) + + values = Dict("a3_1"=>1, "a3_2"=>-2, "a3_3"=>3) + test_expr.op = :abs; test_expr.children = a3[2:2] + assign!(test_expr, values) + @test value(test_expr) == 2 && value(a3[2]) == -2 values = Dict("ar2_1"=>1., "ar2_2"=>2.) @satvariable(ar2[1:2], Real) From 44c3e5c1b85b0c0eceb9097eb464ac3029be0183 Mon Sep 17 00:00:00 2001 From: Emiko Soroka Date: Wed, 29 Nov 2023 11:31:36 -0800 Subject: [PATCH 4/7] Include docstrings for abs and mod. All tests pass. This compliance check against the SMT-LIB Ints and Reals theory definition is complete. --- docs/src/functions.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/functions.md b/docs/src/functions.md index 31a6774..15c2bb3 100644 --- a/docs/src/functions.md +++ b/docs/src/functions.md @@ -37,10 +37,12 @@ For a formal definition of the theory of integer arithmetic, see Figure 3.3 in * ```@docs Base.:-(a::IntExpr) +Base.abs(a::IntExpr) Base.:+(a::IntExpr, b::IntExpr) Base.:-(a::IntExpr, b::IntExpr) Base.:*(a::RealExpr, b::RealExpr) Base.div(a::IntExpr, b::IntExpr) +Base.mod(a::IntExpr, b::IntExpr) Base.:/(a::RealExpr, b::RealExpr) ``` From f621d42688bf76ed54c0b1d7521b4c8154d48230 Mon Sep 17 00:00:00 2001 From: Emiko Soroka Date: Thu, 30 Nov 2023 12:12:19 -0800 Subject: [PATCH 5/7] Added unittest coverage for remaining BitVector operations: rotate_left, rotate_right, zero_extend, sign_extend, repeat, bvcomp. Added to docs. --- docs/src/functions.md | 43 +++++++------ src/BitVectorExpr.jl | 130 +++++++++++++++++++++++++++++++++++--- src/IntExpr.jl | 1 - src/Satisfiability.jl | 10 ++- src/smt_representation.jl | 11 +++- test/bitvector_tests.jl | 43 ++++++++++++- 6 files changed, 199 insertions(+), 39 deletions(-) diff --git a/docs/src/functions.md b/docs/src/functions.md index 15c2bb3..9016e24 100644 --- a/docs/src/functions.md +++ b/docs/src/functions.md @@ -77,9 +77,10 @@ The SMT-LIB standard BitVector is often used to represent operations on fixed-si ### Bitwise operators In addition to supporting the comparison operators above and arithmetic operators `+`, `-`, and `*`, the following BitVector-specific operators are available. -Note that unsigned integer division is available using `div`. +Note that unsigned integer division is available using `div`. Signed division is `sdiv`. ```@docs Base.div(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) +sdiv(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) ``` The bitwise logical operator symbols `&`, `~` and `|` are provided for BitVector types instead of the Boolean logic symbols. This matches Julia's use of bitwise logical operators for Unsigned integer types. @@ -88,17 +89,36 @@ The bitwise logical operator symbols `&`, `~` and `|` are provided for BitVector Base.:~(a::BitVectorExpr{UInt8}) Base.:|(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) Base.:&(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) +nor(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) +nand(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) +xnor(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) Base.:<<(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) +Base.:>>(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) Base.:>>>(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) urem(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) +srem(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) +smod(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) +``` +Signed comparisons. +```@docs +slt(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) +sle(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) +sgt(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) +sge(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) ``` -The following word-level operations are also available in the SMT-LIB standard. +The following word-level operations are also available in the SMT-LIB standard, either as core operations or defined in the [SMT-LIB BitVector logic](https://smtlib.cs.uiowa.edu/logics-all.shtml#QF_BV). ```@docs concat(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) +repeat(a::BitVectorExpr{UInt8}, n::Int64) Base.getindex(a::BitVectorExpr{UInt8}, ind::UnitRange{Int64}) bv2int(a::BitVectorExpr{UInt8}) int2bv(a::IntExpr, s::Int) +bvcomp(a::BitVectorExpr{UInt8}, BitVectorExpr{UInt8}) +zero_extend(a::BitVectorExpr{UInt8}, n::Int64) +sign_extend(a::BitVectorExpr{UInt8}, n::Int64) +rotate_left(a::BitVectorExpr{UInt8}, n::Int64) +rotate_right(a::BitVectorExpr{UInt8}, n::Int64) ``` ### Utility functions for BitVectors @@ -108,25 +128,6 @@ nextsize(n::Integer) bvconst(c::Integer, size::Int) ``` -### Additional Z3 BitVector operators. -Z3 implements the following signed comparisons for BitVectors. Note that these are not part of the SMT-LIB standard and other solvers may not support them. -```@docs -Base.:>>(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) -srem(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) -smod(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) -nor(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) -nand(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) -xnor(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) -``` - -Signed comparisons are also Z3-specific. -```@docs -slt(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) -sle(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) -sgt(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) -sge(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) -``` - ## Generating the SMT representation of a problem ```@docs diff --git a/src/BitVectorExpr.jl b/src/BitVectorExpr.jl index 1ce013c..149617d 100644 --- a/src/BitVectorExpr.jl +++ b/src/BitVectorExpr.jl @@ -1,5 +1,5 @@ -import Base.getindex, Base.setproperty! -import Base.+, Base.-, Base.*, Base.<<, Base.>>, Base.>>>, Base.div, Base.&, Base.|, Base.~ +import Base.getindex, Base.setproperty!, Base.length +import Base.+, Base.-, Base.*, Base.<<, Base.>>, Base.>>>, Base.div, Base.&, Base.|, Base.~, Base.repeat if VERSION.minor >= 0x07 import Base.nor, Base.nand end @@ -59,6 +59,8 @@ end # some utility functions +Base.length(e::AbstractBitVectorExpr) = e.length + """" nextsize(n::Integer) @@ -146,6 +148,13 @@ Unsigned integer division of two BitVectors. """ div(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = __bvnop(div, :bvudiv, BitVectorExpr, [e1, e2]) +""" + sdiv(a::BitVectorExpr, b::BitVectorExpr) + +Signed integer division of two BitVectors. +""" +sdiv(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = __bvnop(div, :bvsdiv, BitVectorExpr, [e1, e2]) + # unary minus, this is an arithmetic minus not a bit flip. -(e::BitVectorExpr) = __bv1op(e, -, :bvneg) @@ -346,8 +355,20 @@ sge(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = __bvnop(__signf ##### Word-level operations ##### -# concat and extract are the only SMT-LIB standard operations -# z3 adds some more, note that concat can accept constants and has arity n >= 2 +# see https://smtlib.cs.uiowa.edu/logics-all.shtml#QF_BV +""" + bvcomp(a, b) + bvcomp(a, bvconst(a, 0xffff, 16)) + +Bitwise comparator: iff all bits of `a` and `b` are equal, `bvcomp(a,b) = 0b1`, otherwise `0b0`. +""" +function bvcomp(a::AbstractBitVectorExpr, b::AbstractBitVectorExpr) + ReturnIntType = nextsize(1) + return BitVectorExpr{ReturnIntType}(:bvcomp, [a,b], bvcomp(a.value, b.value), __get_hash_name(:bvcomp, [a,b]), true) +end + +bvcomp(a::Integer, b::Integer) = a == b ? 0b1 : 0b0 + """ concat(a, b) concat(a, bvconst(0xffff, 16), b, bvconst(0x01, 8), ...) @@ -364,12 +385,12 @@ Arguments are concatenated such that the first argument to concat corresponds to println(expr.value) # 0x01023 ``` """ -function concat(es_mixed::Vararg{Any}) - es_mixed = collect(es_mixed) +function concat(es_mixed::Array{T}) where T <: Any vars, consts = __check_inputs_nary_op(es_mixed, const_type=Integer, expr_type=BitVectorExpr) # only consts if isnothing(vars) || length(vars)==0 - return concat(consts) + lengths = map(bitcount, consts) + return __concat(consts, lengths, nextsize(sum(lengths))) end # preserve order of inputs @@ -387,6 +408,8 @@ function concat(es_mixed::Vararg{Any}) return BitVectorExpr{ReturnType}(:concat, collect(es), value, name, l) end +concat(es_mixed::Vararg{Any}) = concat(collect(es_mixed)) + # for constant values function concat(vals::Array{T}) where T <: Integer lengths = map(bitcount, vals) @@ -402,6 +425,14 @@ function __concat(vals::Array{T}, bitsizes::Array{R}, ReturnType::Type) where T return value end +""" + repeat(a::BitVectorExpr, n) + repeat(bvconst(0xffff, 16), n) + +Repeat bitvector `a` `n` times. +""" +repeat(a::AbstractBitVectorExpr, n::Int64) = concat([a for i=1:n]) +repeat(a::Integer, n::Int64) = concat([a for i=1:n]) ##### INDEXING ##### # SMT-LIB indexing is called extract and works in a slightly weird manner @@ -415,7 +446,7 @@ Base.getindex(e::AbstractBitVectorExpr, ind::Int64) = getindex(e, ind, ind) a[4:8] # has length 5 a[3] -Slice or index into a BitVector, returning a new BitVector with the appropriate length. This corresponds to the SMT-LIB operation `extract`. +Slice or index into a `BitVectorExpr`, returning a new `BitVectorExpr` with the appropriate length. This corresponds to the SMT-LIB operation `extract`. """ function Base.getindex(e::AbstractBitVectorExpr, ind::UnitRange{Int64}) if first(ind) > last(ind) || first(ind) < 1 || last(ind) > e.length @@ -431,6 +462,75 @@ function Base.getindex(e::AbstractBitVectorExpr, ind::UnitRange{Int64}) end +##### Extension and rotation ##### +""" + zero_extend(a::BitVectorExpr, n::Int) + +Pad `BitVectorExpr` `a` with zeros. `n` specifies the number of bits and must be nonnegative. +""" +function zero_extend(e::AbstractBitVectorExpr, n::Int64) + if n < 0 + error("n must be nonnegative for zero_extend!") + end + ReturnIntType = nextsize(length(e) + n) + v = isnothing(e.value) ? nothing : zero_extend(e.value, length(e), n) + return SlicedBitVectorExpr{ReturnIntType}(:zero_extend, [e], v, __get_hash_name(:zero_extend, [e]), length(e) + n, false, n) +end +function zero_extend(v::T, val_len::Int64, n::Int64) where T <: Integer + ReturnIntType = nextsize(val_len + n) + return ReturnIntType(v) +end + +""" + sign_extend(a::BitVectorExpr, n::Int) + +Pad `BitVectorExpr` `a` with 0 or 1 depending on its sign. `n` specifies the number of bits and must be nonnegative. +""" +function sign_extend(e::AbstractBitVectorExpr, n::Int64) + if n < 0 + error("n must be nonnegative for sign_extend!") + end + ReturnIntType = nextsize(length(e) + n) + v = isnothing(e.value) ? nothing : sign_extend(e.value, length(e), n) + return SlicedBitVectorExpr{ReturnIntType}(:sign_extend, [e], v, __get_hash_name(:sign_extend, [e]), length(e) + n, false, n) +end +function sign_extend(v::T, val_len::Int64, n::Int64) where T <: Integer + ReturnIntType = nextsize(val_len + n) + pad = signed(v) > 0 ? ReturnIntType(0) : (typemax(ReturnIntType) << val_len) + return ReturnIntType(v) | pad +end + +""" + rotate_left(a::BitVectorExpr, n::Int) + +Rotate `BitVectorExpr` `a` by n bits left. `n` must be nonnegative. +""" +function rotate_left(e::AbstractBitVectorExpr, n::Int64) + if n < 0 + error("n must be nonnegative for rotate_left!") + end + ReturnIntType = typeof(e).parameters[1] + v = isnothing(e.value) ? nothing : bitrotate(e.value, n) # bitrotate goes left + return SlicedBitVectorExpr{ReturnIntType}(:rotate_left, [e], v, __get_hash_name(:rotate_left, [e]), length(e), false, n) +end +rotate_left(v::T, n::Int64) where T <: Integer = bitrotate(v, n) + +""" + rotate_right(a::BitVectorExpr, n::Int) + +Rotate `BitVectorExpr` `a` by n bits right. `n` must be nonnegative. +""" +function rotate_right(e::AbstractBitVectorExpr, n::Int64) + if n < 0 + error("n must be nonnegative for rotate_right!") + end + ReturnIntType = typeof(e).parameters[1] + v = isnothing(e.value) ? nothing : bitrotate(e.value, -n) # bitrotate goes left, so -n is a right rotation + return SlicedBitVectorExpr{ReturnIntType}(:rotate_right, [e], v, __get_hash_name(:rotate_right, [e]), length(e), false, n) +end +rotate_right(v::T, n::Int64) where T <: Integer = bitrotate(v, -n) + + ##### Translation to/from integer ##### # Be aware these have high overhead """ @@ -495,7 +595,7 @@ end # Constants may be specified in base 10 as long as they are explicitly constructed to be of type Unsigned or BigInt. # Examples: 0xDEADBEEF (UInt32), 0b0101 (UInt8), 0o7700 (UInt16), big"123456789012345678901234567890" (BigInt) # Consts can be padded, so for example you can add 0x01 (UInt8) to (_ BitVec 16) -# Variables cannot be padded! For example, 0x0101 (Uint16) cannot be added to (_ BitVec 8). +# Variables cannot be implicitly padded! For example, 0x0101 (Uint16) cannot be added to (_ BitVec 8). To add these, use sign_extend or zero_extend. __2ops = [:+, :-, :*, :/, :<, :<=, :>, :>=, :(==), :!=, :sle, :slt, :sge, :sgt, :nand, :nor, :<<, :>>, :>>>, :&, :|, :~, :srem, :urem, :smod] @@ -514,6 +614,7 @@ end __bitvector_const_ops = Dict( :bvudiv => div, + :bvsdiv => __signfix(div), # TODO check :bvshl => (<<), :bvlshr => (>>>), :bvashr => __signfix(>>), @@ -538,6 +639,7 @@ __bitvector_const_ops = Dict( :bvsle => __signfix(<=), :bvsgt => __signfix(>=), :bvsge => __signfix(>), + :bvcomp => (a,b) -> a == b ? 0b1 : 0b0, # see https://smtlib.cs.uiowa.edu/logics-all.shtml#QF_BV :eq => (==) ) @@ -551,15 +653,25 @@ function __propagate_value!(z::AbstractBitVectorExpr) if z.op == :concat ls = getproperty.(z.children, :length) z.value = __concat(vs, ls, nextsize(z.length)) + elseif z.op == :int2bv z.value = nextsize(z.length)(z.children[1].value) + elseif z.op == :extract ReturnIntType = typeof(z).parameters[1] v = z.children[1].value z.value = v & ReturnIntType(reduce(|, map((i) -> 2^(i-1), z.range))) + + elseif z.op in [:rotate_left, :rotate_right] + z.value = eval(z.op)(z.children[1].value, z.range) + + elseif z.op in [:zero_extend, :sign_extend] + z.value = eval(z.op)(z.children[1].value, length(z.children[1]), z.range) + elseif z.op ∈ keys(__bitvector_const_ops) op = __bitvector_const_ops[z.op] z.value = length(vs)>1 ? op(vs...) : op(vs[1]) + else op = eval(z.op) z.value = length(vs)>1 ? op(vs...) : op(vs[1]) diff --git a/src/IntExpr.jl b/src/IntExpr.jl index f365bc9..84e7dd8 100644 --- a/src/IntExpr.jl +++ b/src/IntExpr.jl @@ -469,7 +469,6 @@ Base.promote_rule(::Type{IntExpr}, ::Type{BoolExpr}) = IntExpr Base.promote_rule(::Type{RealExpr}, ::Type{BoolExpr}) = RealExpr Base.promote_rule(::Type{RealExpr}, ::Type{IntExpr}) = RealExpr - Base.convert(::Type{IntExpr}, z::BoolExpr) = z.op == :const ? __wrap_const(Int64(z.value)) : ite(z, 1, 0) Base.convert(::Type{RealExpr}, z::BoolExpr) = z.op == :const ? __wrap_const(Float64(z.value)) : ite(z, 1.0, 0.0) Base.convert(::Type{RealExpr}, a::IntExpr) = a.op == :const ? __wrap_const(Float64(a.value)) : to_real(a) diff --git a/src/Satisfiability.jl b/src/Satisfiability.jl index 1ff19fc..2a94af7 100644 --- a/src/Satisfiability.jl +++ b/src/Satisfiability.jl @@ -30,6 +30,7 @@ export distinct export +, -, *, /, + abs, mod, div, to_real, to_int @@ -37,20 +38,23 @@ export export nextsize, bitcount, - div, - urem, + bvcomp, <<, >>, >>>, &, |, ~, + urem, srem, smod, + sdiv, nor, ⊽, nand, ⊼, xnor, slt, sle, sgt, sge, - concat, + concat, repeat, + zero_extend, sign_extend, + rotate_left, rotate_right, bv2int, int2bv, bvconst diff --git a/src/smt_representation.jl b/src/smt_representation.jl index 1a1ea4d..11c5666 100644 --- a/src/smt_representation.jl +++ b/src/smt_representation.jl @@ -36,9 +36,14 @@ __smt_symbolic_ops = Dict( # These are extra-special cases where the operator name is not ASCII and has to be generated at runtime __smt_generated_ops = Dict( - :int2bv => (e::AbstractBitVectorExpr) -> "(_ int2bv $(e.length))", - :extract => (e::AbstractBitVectorExpr) -> "(_ extract $(last(e.range)-1) $(first(e.range)-1))", - :ufunc => (e::AbstractExpr) -> split(e.name, "_")[1] + :int2bv => (e::AbstractBitVectorExpr) -> "(_ int2bv $(e.length))", + :extract => (e::SlicedBitVectorExpr) -> "(_ extract $(last(e.range)-1) $(first(e.range)-1))", + :repeat => (e::SlicedBitVectorExpr) -> "(_ repeat $(e.range))", + :zero_extend => (e::SlicedBitVectorExpr) -> "(_ zero_extend $(e.range))", + :sign_extend => (e::SlicedBitVectorExpr) -> "(_ sign_extend $(e.range))", + :rotate_left => (e::SlicedBitVectorExpr) -> "(_ rotate_left $(e.range))", + :rotate_right => (e::SlicedBitVectorExpr) -> "(_ rotate_right $(e.range))", + :ufunc => (e::AbstractExpr) -> split(e.name, "_")[1] ) # Finally, we provide facilities for correct encoding of consts diff --git a/test/bitvector_tests.jl b/test/bitvector_tests.jl index e20a156..917d9d8 100644 --- a/test/bitvector_tests.jl +++ b/test/bitvector_tests.jl @@ -22,8 +22,8 @@ CLEAR_VARNAMES!() # unary minus @test (-d).op == :bvneg # combining ops - ops = [+, -, *, div, urem, <<, >>, srem, smod, >>>, nor, nand, xnor] - names = [:bvadd, :bvsub, :bvmul, :bvudiv, :bvurem, :bvshl, :bvashr, :bvsrem, :bvsmod, :bvlshr, :bvnor, :bvnand, :bvxnor] + ops = [+, -, *, div, sdiv, urem, <<, >>, srem, smod, >>>, nor, nand, xnor] + names = [:bvadd, :bvsub, :bvmul, :bvudiv, :bvsdiv, :bvurem, :bvshl, :bvashr, :bvsrem, :bvsmod, :bvlshr, :bvnor, :bvnand, :bvxnor] for (op, name) in zip(ops, names) @test isequal(op(a,b), BitVectorExpr{UInt16}(name, [a,b], nothing, Satisfiability.__get_hash_name(name, [a,b]), 16)) end @@ -148,4 +148,43 @@ end @test a.value == 0xff @test b.value == 0x00 +end + +@testset "Assigning values" begin + assign! = Satisfiability.assign! + @satvariable(a, BitVector, 8) + @satvariable(b, BitVector, 8) + values = Dict("a" => 0x01, "b" => 0xf0) + + expr = a | b; assign!(expr, values) + @test expr.value == 0xf1 + + expr = -a - b; assign!(expr, values) + @test expr.value == -0xf1 + + expr = div(b,a); assign!(expr, values) + @test expr.value == 0xf0 + + expr = sdiv(-b,a); assign!(expr, values) + @test expr.value == div(-0xf0, 0x01) + + expr = repeat(a, 3); assign!(expr, values) + @test expr.value == 0x010101 + + expr = zero_extend(a, 4); assign!(expr, values) + @test expr.value == 0x0001 + + expr = sign_extend(-a, 4); assign!(expr, values) + @test expr.value == 0xffff + + expr = rotate_left(b, 4); assign!(expr, values) + @test expr.value == 0x0f + + expr = rotate_right(b, 4); assign!(expr, values) + @test expr.value == 0x0f + + expr = bvcomp(a,a); assign!(expr, values) + @test expr.value == 0b1 + expr = bvcomp(a,b); assign!(expr, values) + @test expr.value == 0b0 end \ No newline at end of file From 41539fb408eb5d3a08ddd3eb1e286fc7d39ca52d Mon Sep 17 00:00:00 2001 From: Emiko Soroka Date: Thu, 30 Nov 2023 14:30:48 -0800 Subject: [PATCH 6/7] Documentation cleanup, verified no missing docstrings. --- docs/src/functions.md | 14 +++++++------- src/BitVectorExpr.jl | 26 +++++++++++++------------- src/IntExpr.jl | 6 +++--- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/docs/src/functions.md b/docs/src/functions.md index 9016e24..33b7ce2 100644 --- a/docs/src/functions.md +++ b/docs/src/functions.md @@ -16,7 +16,7 @@ An **uninterpreted function** is a function where the mapping between input and ## Logical operations -These are operations in the theory of propositional logic. For a formal definition of this theory, see Figure 3.2 in *The SMT-LIB Standard, Version 2.6* or the SMT-LIB [Core theory declaration](http://smtlib.cs.uiowa.edu/theories.shtml). +These are operations in the theory of propositional logic. For a formal definition of this theory, see Figure 3.2 in [*The SMT-LIB Standard, Version 2.6*](https://smtlib.cs.uiowa.edu/standard.shtml) or the SMT-LIB [Core theory declaration](http://smtlib.cs.uiowa.edu/theories.shtml). ```@docs not(z::BoolExpr) and(z1::BoolExpr, z2::BoolExpr) @@ -25,7 +25,7 @@ xor(zs_mixed::Array{T}; broadcast_type=:Elementwise) where T implies(z1::BoolExpr, z2::BoolExpr) iff(z1::BoolExpr, z2::BoolExpr) -ite(x::Union{BoolExpr, Bool}, y::Union{BoolExpr, Bool}, z::Union{BoolExpr, Bool}) +ite(x::BoolExpr, y::BoolExpr, z::BoolExpr) distinct(z1::BoolExpr, z2::BoolExpr) ``` @@ -41,8 +41,8 @@ Base.abs(a::IntExpr) Base.:+(a::IntExpr, b::IntExpr) Base.:-(a::IntExpr, b::IntExpr) Base.:*(a::RealExpr, b::RealExpr) -Base.div(a::IntExpr, b::IntExpr) -Base.mod(a::IntExpr, b::IntExpr) +div(a::IntExpr, b::IntExpr) +mod(a::IntExpr, b::IntExpr) Base.:/(a::RealExpr, b::RealExpr) ``` @@ -109,12 +109,12 @@ sge(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) The following word-level operations are also available in the SMT-LIB standard, either as core operations or defined in the [SMT-LIB BitVector logic](https://smtlib.cs.uiowa.edu/logics-all.shtml#QF_BV). ```@docs -concat(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) +concat(a::Array{T}) where T repeat(a::BitVectorExpr{UInt8}, n::Int64) Base.getindex(a::BitVectorExpr{UInt8}, ind::UnitRange{Int64}) bv2int(a::BitVectorExpr{UInt8}) int2bv(a::IntExpr, s::Int) -bvcomp(a::BitVectorExpr{UInt8}, BitVectorExpr{UInt8}) +bvcomp(a::BitVectorExpr{UInt8}, b::BitVectorExpr{UInt8}) zero_extend(a::BitVectorExpr{UInt8}, n::Int64) sign_extend(a::BitVectorExpr{UInt8}, n::Int64) rotate_left(a::BitVectorExpr{UInt8}, n::Int64) @@ -132,7 +132,7 @@ bvconst(c::Integer, size::Int) ```@docs smt(zs::Array{T}) where T <: BoolExpr -save(prob::BoolExpr, io::IO) +save(prob::BoolExpr) ``` ## Solving a SAT problem diff --git a/src/BitVectorExpr.jl b/src/BitVectorExpr.jl index 149617d..f65da39 100644 --- a/src/BitVectorExpr.jl +++ b/src/BitVectorExpr.jl @@ -61,7 +61,7 @@ end # some utility functions Base.length(e::AbstractBitVectorExpr) = e.length -"""" +""" nextsize(n::Integer) Returns the smallest unsigned integer type that can store a number with n bits. @@ -80,7 +80,7 @@ function nextsize(n::Integer) # works on BigInt and UInt end end -"""" +""" bitcount(a::Integer) Returns the minimum number of bits required to store the number `a`. @@ -198,21 +198,21 @@ Logical right shift a >>> b. """ srem(a::BitVectorExpr, b::BitVectorExpr) -Signed remainder of BitVector a divided by BitVector b. This operator is not part of the SMT-LIB standard BitVector theory: it is implemented by Z3. It may not be available when using other solvers. +Signed remainder of BitVector a divided by BitVector b. """ srem(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = __bvnop(__signfix(rem), :bvsrem, BitVectorExpr, [e1, e2]) # unique to z3 """ smod(a::BitVectorExpr, b::BitVectorExpr) -Signed modulus of BitVector a divided by BitVector b. This operator is not part of the SMT-LIB standard BitVector theory: it is implemented by Z3. It may not be available when using other solvers. +Signed modulus of BitVector a divided by BitVector b. """ smod(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = __bvnop(__signfix(mod), :bvsmod, BitVectorExpr, [e1, e2]) # unique to z3 """ a >> b -Arithmetic right shift a >> b. This operator is not part of the SMT-LIB standard BitVector theory: it is implemented by Z3. It may not be available when using other solvers. +Arithmetic right shift a >> b. """ >>(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = __bvnop(__signfix(>>), :bvashr, BitVectorExpr, [e1, e2]) # arithmetic shift right - unique to Z3 @@ -262,7 +262,7 @@ Bitwise and. For n>2 variables, use the and(...) notation. nor(a, b) a ⊽ b -Bitwise nor. This operator is not part of the SMT-LIB standard BitVector theory: it is implemented by Z3. It may not be available when using other solvers. When using other solvers, write ~(a | b) isntead of nor(a,b). +Bitwise nor. """ nor(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = __bvnop((a,b) -> ~(a | b), :bvnor, BitVectorExpr, [e1, e2], __is_commutative=true) ⊽(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = nor(e1, e2) @@ -271,7 +271,7 @@ nor(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = __bvnop((a,b) -> nand(a, b) a ⊼ b -Bitwise nand. This operator is not part of the SMT-LIB standard BitVector theory: it is implemented by Z3. It may not be available when using other solvers. When using other solvers, write ~(a & b) isntead of nand(a,b). +Bitwise nand. """ nand(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = __bvnop((a,b) -> ~(a & b), :bvnand, BitVectorExpr, [e1, e2], __is_commutative=true) ⊼(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = nand(e1, e2) @@ -297,7 +297,7 @@ end xnor(a, b) xnor(a, b, c...) -Bitwise xnor. When n>2 operands are provided, xnor is left-associative (that is, `xnor(a, b, c) = reduce(xnor, [a,b,c])`. This operator is not part of the SMT-LIB standard BitVector theory: it is implemented by Z3. It may not be available when using other solvers. When using other solvers, write (a & b) | (~a & ~b). +Bitwise xnor. When n>2 operands are provided, xnor is left-associative (that is, `xnor(a, b, c) = reduce(xnor, [a,b,c])`. """ xnor(zs::Vararg{Union{T, Integer}}) where T <: AbstractBitVectorExpr = xnor(collect(zs)) # We need this declaration to enable the syntax and.([z1, z2,...,zn]) where z1, z2,...,zn are broadcast-compatible @@ -328,28 +328,28 @@ end """" slt(a::BitVectorExpr, b::BitVectorExpr) -Signed less-than. This is not the same as a < b (unsigned BitVectorExpr comparison). This operator is not part of the SMT-LIB standard BitVector theory: it is implemented by Z3. It may not be available when using other solvers. +Signed less-than. This is not the same as a < b (unsigned BitVectorExpr comparison). """ slt(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = __bvnop(__signfix(>), :bvslt, BoolExpr, [e1, e2]) """ sle(a::BitVectorExpr, b::BitVectorExpr) -Signed less-than-or-equal. This is not the same as a <+ b (unsigned BitVectorExpr comparison). This operator is not part of the SMT-LIB standard BitVector theory: it is implemented by Z3. It may not be available when using other solvers. +Signed less-than-or-equal. This is not the same as a <+ b (unsigned BitVectorExpr comparison). """ sle(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = __bvnop(__signfix(>=), :bvsle, BoolExpr, [e1, e2]) """ sgt(a::BitVectorExpr, b::BitVectorExpr) -Signed greater-than. This is not the same as a > b (unsigned BitVectorExpr comparison). This operator is not part of the SMT-LIB standard BitVector theory: it is implemented by Z3. It may not be available when using other solvers. +Signed greater-than. This is not the same as a > b (unsigned BitVectorExpr comparison). """ sgt(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = __bvnop(__signfix(>), :bvsgt, BoolExpr, [e1, e2]) """ sge(a::BitVectorExpr, b::BitVectorExpr) -Signed greater-than-or-equal. This is not the same as a >= b (unsigned BitVectorExpr comparison). This operator is not part of the SMT-LIB standard BitVector theory: it is implemented by Z3. It may not be available when using other solvers. +Signed greater-than-or-equal. This is not the same as a >= b (unsigned BitVectorExpr comparison). """ sge(e1::AbstractBitVectorExpr, e2::AbstractBitVectorExpr) = __bvnop(__signfix(>=), :bvsge, BoolExpr, [e1, e2]) @@ -653,7 +653,7 @@ function __propagate_value!(z::AbstractBitVectorExpr) if z.op == :concat ls = getproperty.(z.children, :length) z.value = __concat(vs, ls, nextsize(z.length)) - + elseif z.op == :int2bv z.value = nextsize(z.length)(z.children[1].value) diff --git a/src/IntExpr.jl b/src/IntExpr.jl index 84e7dd8..fa868d8 100644 --- a/src/IntExpr.jl +++ b/src/IntExpr.jl @@ -408,7 +408,7 @@ Base.:*(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeric div(a, b) div(a, 2) -Returns the `Int` division expression `div(a,b)`. Note: `a` and `b` will be converted to `IntExpr`). Use dot broadcasting for vector-valued and matrix-valued expressions. +Returns the `Int` division expression `div(a,b)`. Note: `a` and `b` will be converted to `IntExpr`. Use dot broadcasting for vector-valued and matrix-valued expressions. ```julia @satvariable(a[1:n], Int) @@ -425,7 +425,7 @@ Base.div(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeri mod(a, b) mod(a, 2) -Returns the `Int` modulus expression `mod(a,b)`. Note: `a` and `b` will be converted to `IntExpr`). Use dot broadcasting for vector-valued and matrix-valued expressions. +Returns the `Int` modulus expression `mod(a,b)`. Note: `a` and `b` will be converted to `IntExpr`. Use dot broadcasting for vector-valued and matrix-valued expressions. """ Base.mod(e1::NumericInteroperableExpr, e2::NumericInteroperableExpr) = __numeric_n_ary_op([convert(IntExpr, e1), convert(IntExpr, e2)], :mod) Base.mod(e1::NumericInteroperableExpr, e2::NumericInteroperableConst) = __numeric_n_ary_op([convert(IntExpr, e1), __wrap_const(Int(floor(e2)))], :mod) @@ -435,7 +435,7 @@ Base.mod(e1::NumericInteroperableConst, e2::NumericInteroperableExpr) = __numeri a / b a / 2.0 -Returns the `Real` division expression `a/b`. Note: `a` and `b` will be converted to `RealExpr`). Use dot broadcasting for vector-valued and matrix-valued expressions. +Returns the `Real` division expression `a/b`. Note: `a` and `b` will be converted to `RealExpr`. Use dot broadcasting for vector-valued and matrix-valued expressions. ```julia @satvariable(a[1:n], Real) From 125a965b7722adf5d10d5282ec9e8138a8fbbdc6 Mon Sep 17 00:00:00 2001 From: Emiko Soroka Date: Thu, 30 Nov 2023 15:13:02 -0800 Subject: [PATCH 7/7] Increase unittest coverage for Codecov --- test/bitvector_tests.jl | 14 ++++++++++++++ test/int_real_tests.jl | 17 ++++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/test/bitvector_tests.jl b/test/bitvector_tests.jl index 917d9d8..3721d20 100644 --- a/test/bitvector_tests.jl +++ b/test/bitvector_tests.jl @@ -122,6 +122,20 @@ end @test smt(a[1:8] == 0xff) == "(declare-fun a () (_ BitVec 8)) (assert (= ((_ extract 7 0) a) #xff))\n" + + @satvariable(x, BitVector, 8) + @test smt(repeat(x,2) == 0xff) == "(declare-fun x () (_ BitVec 8)) +(assert (= (concat x x) #x00ff))\n" + + @test smt(zero_extend(x,4) == 0x0) == "(declare-fun x () (_ BitVec 8)) +(assert (= ((_ zero_extend 4) x) #x000))\n" + @test smt(sign_extend(x,4) == 0x0) == "(declare-fun x () (_ BitVec 8)) +(assert (= ((_ sign_extend 4) x) #x000))\n" + + @test smt(rotate_left(x,2) == 0x0) == "(declare-fun x () (_ BitVec 8)) +(assert (= ((_ rotate_left 2) x) #x00))\n" + @test smt(rotate_right(x,2) == 0x0) == "(declare-fun x () (_ BitVec 8)) +(assert (= ((_ rotate_right 2) x) #x00))\n" end @testset "BitVector result parsing" begin diff --git a/test/int_real_tests.jl b/test/int_real_tests.jl index ed9f699..56a9047 100644 --- a/test/int_real_tests.jl +++ b/test/int_real_tests.jl @@ -10,6 +10,10 @@ using Test @satvariable(br[1:2], Real) @satvariable(cr[1:1,1:2], Real) + @satvariable(z, Bool) + @test isequal(convert(IntExpr, z), ite(z, 1, 0)) + @test isequal(convert(RealExpr, z), ite(z, 1.0, 0.0)) + a.value = 2; b[1].value = 1 @test isequal((a .< b)[1], BoolExpr(:lt, AbstractExpr[a, b[1]], false, Satisfiability.__get_hash_name(:lt, [a,b[1]]))) @test isequal((a .>= b)[1], BoolExpr(:geq, AbstractExpr[a, b[1]], true, Satisfiability.__get_hash_name(:geq, [a,b[1]]))) @@ -24,9 +28,9 @@ using Test # Construct with constants on RHS c[1,2].value = 1 c[1,1].value = 0 - @test isequal((c .>= 0)[1,1] , c[1,1] >= 0) && isequal((c .<= 0.0)[1,1] , c[1,1] <= 0.0) - @test isequal((c .== 0)[1,1] , c[1,1] == 0) - @test isequal((c .< 0)[1,1] , c[1,1] < 0) && isequal((c .> 0)[1,1] , c[1,1] > 0) + @test isequal((cr .>= 0)[1,1] , cr[1,1] >= 0) && isequal((cr .<= 0.0)[1,1] , cr[1,1] <= 0.0) + @test isequal((cr .== false)[1,1] , cr[1,1] == false) + @test isequal((cr .< false)[1,1] , cr[1,1] < false) && isequal((cr .> 0)[1,1] , cr[1,1] > 0) # Construct with constants on LHS @@ -39,6 +43,7 @@ using Test @test isequal(distinct(c[1,2], c[1,1]), c[1,2] != c[1,1]) @test distinct(3,4) && !distinct(true, true) @test isequal(distinct(b), distinct(b[2], b[1])) + @test isequal(distinct(ar, 2), distinct(ar, 2.0)) end @testset "Construct n-ary ops" begin @@ -68,15 +73,17 @@ end @test isequal(sum([a, 1.0, 1, false, b[1]]), RealExpr(:add, children, nothing, Satisfiability.__get_hash_name(:add, children, is_commutative=true))) # Sum works automatically - @test isequal(1 + a + b[1] + true, sum([1, a, b[1], true])) + @test isequal(1 + div(a, b[1]) + mod(b[1], b[2]) + true, sum([1, div(a, b[1]), mod(b[1], b[2]), true])) @test all(isequal.((a - 3).children, [a, IntExpr(:const, AbstractExpr[], 3, "const_3")])) @test all(isequal.((ar/3.0).children, [ar, RealExpr(:const, AbstractExpr[], 3., "const_3.0")])) # div, /, mod type coercion @test isequal(div(2.0, ar), div(2, to_int(ar))) + @test isequal(div(ar, 2.0), div(to_int(ar), 2)) + @test isequal(mod(ar, 3.0), mod(to_int(ar), 3)) + @test isequal(mod(3.0, ar), mod(3, to_int(ar))) @test isequal(a/2, to_real(a)/2.0) - @test isequal(mod(ar, 2.0), mod(to_int(ar), 2)) # abs rewrites to ite for non-int variables @satvariable(z, Bool)