Skip to content

Commit

Permalink
Added test coverage and fixed bugs for propagating values with __assign!
Browse files Browse the repository at this point in the history
  • Loading branch information
elsoroka committed Jul 1, 2023
1 parent ad872c4 commit 422bcb5
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 49 deletions.
14 changes: 12 additions & 2 deletions src/BoolExpr.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Base.length, Base.size, Base.show, Base.string, Base.==, Base.broadcastable
import Base.length, Base.size, Base.show, Base.string, Base.isequal, Base.hash, Base.broadcastable

##### TYPE DEFINITIONS #####

Expand Down Expand Up @@ -74,6 +74,16 @@ function Base.string(expr::AbstractExpr, indent=0)::String
end

"Test equality of two BoolExprs."
function (==)(expr1::AbstractExpr, expr2::AbstractExpr)
function Base.isequal(expr1::AbstractExpr, expr2::AbstractExpr)
return (expr1.op == expr2.op) && all(expr1.value .== expr2.value) && (expr1.name == expr2.name) && (__is_permutation(expr1.children, expr2.children))
end

# Required for isequal apparently, since isequal(expr1, expr2) implies hash(expr1) == hash(expr2).
function Base.hash(expr::AbstractExpr)
return hash("$(show(expr))")
end

# Overload because Base.in uses == which se used to construct equality expressions
function Base.in(expr::T, exprs::Array{T}) where T <: AbstractExpr
return any(isequal.(expr, exprs))
end
129 changes: 113 additions & 16 deletions src/IntExpr.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
import Base.Int, Base.Real
import Base.<, Base.<=, Base.>, Base.<=, Base.+, Base.-, Base.*, Base./# Base.sum, Base.prod
import Base.<, Base.<=, Base.>, Base.<=, Base.+, Base.-, Base.*, Base./, Base.==

abstract type NumericExpr <: AbstractExpr end


mutable struct IntExpr <: NumericExpr
op :: Symbol
children :: Array{AbstractExpr}
value :: Union{Int, Nothing, Missing}
value :: Union{Int, Bool, Nothing, Missing}
name :: String
end

"""
Int("a")
Construct a single Int variable with name "a".
```julia
Int(n, "a")
Int(m, n, "a")
```
Construct a vector-valued or matrix-valued Int variable with name "a".
Vector and matrix-valued Ints use Julia's built-in array functionality: calling `Int(n,"a")` returns a `Vector{IntExpr}`, while calling `Int(m, n, "a")` returns a `Matrix{IntExpr}`.
"""
function Base.Int(name::String) :: IntExpr
# This unsightly bit enables warning when users define two variables with the same string name.
global GLOBAL_VARNAMES
Expand All @@ -29,10 +43,24 @@ Int(m::Int, n::Int, name::String) :: Matrix{IntExpr} = IntExpr[Int("$(name)_$(i)
mutable struct RealExpr <: NumericExpr
op :: Symbol
children :: Array{AbstractExpr}
value :: Union{Float64, Nothing, Missing}
value :: Union{Float64, Bool, Nothing, Missing}
name :: String
end

"""
Real("r")
Construct a single Int variable with name "r".
```julia
Real(n, "r")
Real(m, n, "r")
```
Construct a vector-valued or matrix-valued Real variable with name "r".
Vector and matrix-valued Reals use Julia's built-in array functionality: calling `Real(n,"a")` returns a `Vector{RealExpr}`, while calling `Real(m, n, "r")` returns a `Matrix{RealExpr}`.
"""
function Base.Real(name::String) :: RealExpr
# This unsightly bit enables warning when users define two variables with the same string name.
global GLOBAL_VARNAMES
Expand All @@ -56,28 +84,84 @@ NumericInteroperable = Union{NumericInteroperableExpr, NumericInteroperableConst
__wrap_const(c::Float64) = RealExpr(:CONST, AbstractExpr[], c, "const_$c")
__wrap_const(c::Union{Int, Bool}) = IntExpr(:CONST, AbstractExpr[], c, "const_$c")


##### COMPARISON OPERATIONS ####
# These return Boolean values. In the SMT dialect we would say they have sort Bool
# See figure 3.3 in the SMT-LIB standard.

"""
a < b
a < 0
Returns the Boolean expression a < b. Use dot broadcasting for vector-valued and matrix-valued Boolean expressions.
```julia
a = Int(n, "a")
b = Int(n, m, "b")
a .< b
z = Bool("z")
a .< z
```
"""
function Base.:<(e1::AbstractExpr, e2::AbstractExpr)
value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value < e2.value
name = __get_hash_name(:LT, [e1, e2])
return BoolExpr(:LT, [e1, e2], value, name)
end

"""
a <= b
a <= 0
Returns the Boolean expression a <= b. Use dot broadcasting for vector-valued and matrix-valued Boolean expressions.
```julia
a = Int(n, "a")
b = Int(n, m, "b")
a .<= b
z = Bool("z")
a .<= z
```
"""
function Base.:<=(e1::AbstractExpr, e2::AbstractExpr)
value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value <= e2.value
name = __get_hash_name(:LEQ, [e1, e2])
return BoolExpr(:LEQ, [e1, e2], value, name)
end

"""
a >= b
a >= 0
Returns the Boolean expression a >= b. Use dot broadcasting for vector-valued and matrix-valued Boolean expressions.
```julia
a = Int(n, "a")
b = Int(n, m, "b")
a .>= b
z = Bool("z")
a .>= z
```
"""
function Base.:>=(e1::AbstractExpr, e2::AbstractExpr)
value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value >= e2.value
name = __get_hash_name(:GEQ, [e1, e2])
return BoolExpr(:GEQ, [e1, e2], value, name)
end

"""
a > b
a > 0
Returns the Boolean expression a > b. Use dot broadcasting for vector-valued and matrix-valued Boolean expressions.
```julia
a = Int(n, "a")
b = Int(n, m, "b")
a .> b
z = Bool("z")
a .> z
```
"""
function Base.:>(e1::AbstractExpr, e2::AbstractExpr)
value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value > e2.value
name = __get_hash_name(:GT, [e1, e2])
Expand All @@ -90,7 +174,7 @@ end
# We can't swap the definitions eq and (==) because that breaks Base behavior.
# For example, if (==) generates an equality constraint instead of making a Boolean, you can't write z ∈ [z1,...,zn].

function eq(e1::T, e2::T) where T <: AbstractExpr
function Base.:(==)(e1::T, e2::T) where T <: AbstractExpr
value = isnothing(e1.value) || isnothing(e2.value) ? nothing : e1.value == e2.value
name = __get_hash_name(:EQ, [e1, e2])
return BoolExpr(:EQ, [e1, e2], value, name)
Expand All @@ -107,12 +191,25 @@ Base.:<(e1::NumericInteroperableConst, e2::AbstractExpr) = wrap_const(e1) < e2
Base.:<=(e1::AbstractExpr, e2::NumericInteroperableConst) = e1 <= __wrap_const(e2)
Base.:<=(e1::NumericInteroperableConst, e2::AbstractExpr) = wrap_const(e1) <= e2

eq(e1::AbstractExpr, e2::NumericInteroperableConst) = e1 == __wrap_const(e2)
eq(e1::NumericInteroperableConst, e2::AbstractExpr) = wrap_const(e1) == e2
eq(e1::AbstractExpr, e2::NumericInteroperableConst) = eq(e1, __wrap_const(e2))
eq(e1::NumericInteroperableConst, e2::AbstractExpr) = eq(wrap_const(e1), e2)


##### UNARY OPERATIONS #####
# Ok there's only one... negation.
"""
-(a::IntExpr)
-(r::RealExpr)
Return the negative of an Int or Real expression.
```julia
a = Int(n, "a")
-a # this works
b = Int(n, m, "b")
-b # this also works
```
"""
Base.:-(e::IntExpr) = IntExpr(:NEG, IntExpr[e,], isnothing(e.value) ? nothing : -e.value, __get_hash_name(:NEG, [e,]))
Base.:-(e::RealExpr) = RealExpr(:NEG, RealExpr[e,], isnothing(e.value) ? nothing : -e.value, __get_hash_name(:NEG, [e,]))

Expand All @@ -125,24 +222,24 @@ Base.:-(es::Array{T}) where T <: NumericExpr = .-es
# See figure 3.3 in the SMT-LIB standard.

# If literal is != 0, add a :CONST expr to es representing literal
function add_const!(es::Array{T}, literal::Real) where T <: AbstractExpr
function __add_const!(es::Array{T}, literal::Real) where T <: AbstractExpr
if literal != 0
const_expr = isa(literal, Float64) ? RealExpr(:CONST, AbstractExpr[], literal, "const_$literal") : IntExpr(:CONST, AbstractExpr[], literal, "const_$literal")
push!(es, const_expr)
end
end

# If there is more than one :CONST expr in es, merge them into one
function merge_const!(es::Array{T}) where T <: AbstractExpr
function __merge_const!(es::Array{T}) where T <: AbstractExpr
const_exprs = filter( (e) -> e.op == :CONST, es)
if length(const_exprs) > 1
filter!( (e) -> e.op != :CONST, es)
add_const!(es, sum(getproperty.(const_exprs, :value)))
__add_const!(es, sum(getproperty.(const_exprs, :value)))
end
end

# This is NOT a recursive function. It will only unnest one level.
function unnest(es::Array{T}, op::Symbol) where T <: AbstractExpr
function __unnest(es::Array{T}, op::Symbol) where T <: AbstractExpr
# this is all the child operators that aren't CONST or IDENTITY
child_operators = filter( (op) -> op != :IDENTITY && op != :CONST, getproperty.(es, :op))

Expand All @@ -162,20 +259,20 @@ function __numeric_n_ary_op(es_mixed::Array, op::Symbol)
literal = length(literals) > 0 ? sum(literals) : 0

# flatten nestings, this prevents unsightly things like and(x, and(y, and(z, true)))
es = unnest(es, op)
es = __unnest(es, op)
# now we are guaranteed all es are valid exprs and all literals have been condensed to one
# hack to store literals
add_const!(es, literal)
__add_const!(es, literal)

# Now it is possible we have several CONST exprs. This occurs if, for example, one writes 1 + a + true
# TO clean up, we should merge the CONST exprs
merge_const!(es)
__merge_const!(es)

# Now everything is in es and we are all cleaned up.
# Determine return expr type. Note that / promotes to RealExpr because the SMT theory of integers doesn't include it
ReturnExpr = any(isa.(es, RealExpr)) || op == :DIV ? RealExpr : IntExpr

value = any(isnothing.(getproperty.(es, :value))) ? nothing : sum(values)
value = any(isnothing.(getproperty.(es, :value))) ? nothing : sum(getproperty.(es, :value))
return ReturnExpr(op, es, value, __get_hash_name(op, es))
end

Expand Down
10 changes: 5 additions & 5 deletions src/sat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ __reductions = Dict(
:NOT => (values) -> !(values[1]),
:AND => (values) -> reduce(&, values),
:OR => (values) -> reduce(|, values),
:XOR => (values) -> reduce(xor, values),
:IMPLIES => (values) -> (values[1]) | values[2],
:XOR => (values) -> sum(values) == 1,
:IMPLIES => (values) -> !(values[1]) | values[2],
:IFF => (values) -> values[1] == values[2],
:ITE => (values) -> (values[1] & values[2]) | (values[1] & values[3]),
:EQ => (values) -> values[1] == values[2],
Expand All @@ -53,9 +53,9 @@ __reductions = Dict(
:GT => (values) -> values[1] > values[2],
:GEQ => (values) -> values[1] >= values[2],
:ADD => (values) -> sum(values),
:SUB => (values) -> value[1] - sum(values[2:end])
:MUL => (values) -> prod(values)
:DIV => (values) -> value[1] / prod(values[2:end])
:SUB => (values) -> values[1] - sum(values[2:end]) ,
:MUL => (values) -> prod(values),
:DIV => (values) -> values[1] / prod(values[2:end]),
)

function __assign!(z::T, values::Dict) where T <: AbstractExpr
Expand Down
52 changes: 26 additions & 26 deletions test/boolean_operation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,39 +51,39 @@ end
z23 = Bool(2,3, "z23")

# and(z) = z and or(z) = z
@test and([z1[1]]) == z1[1]
@test isequal(and([z1[1]]), z1[1])

@test or([z23[1]]) == z23[1]
@test isequal(or([z23[1]]), z23[1])

# Can construct with 2 exprs
@test all( (z1 .∧ z32)[1].children .== [z1[1], z32[1]] )
@test all( isequal.((z1 .∧ z32)[1].children, [z1[1], z32[1]] ))
@test (z1 .∧ z32)[1].name == BooleanSatisfiability.__get_hash_name(:AND, [z1[1], z32[1]])
@test all( (z1 .∨ z32)[2,1].children .== [z1[1], z32[2,1]] )
@test all( isequal.((z1 .∨ z32)[2,1].children, [z1[1], z32[2,1]] ))
@test (z1 .∨ z32)[1].name == BooleanSatisfiability.__get_hash_name(:OR, [z1[1], z32[1]])

# Can construct with N>2 exprs
or_N = or.(z1, z12, z32)
and_N = and.(z1, z12, z32)

@test all( or_N[3,2].children .== [z1[1], z12[1,2], z32[3,2]] )
@test all( isequal.(or_N[3,2].children, [z1[1], z12[1,2], z32[3,2]] ))
@test and_N[1].name == BooleanSatisfiability.__get_hash_name(:AND, and_N[1].children)

@test all( or_N[1].children .== [z1[1], z12[1], z32[1]] )
@test all( isequal.(or_N[1].children, [z1[1], z12[1], z32[1]] ))
@test or_N[1].name == BooleanSatisfiability.__get_hash_name(:OR, and_N[1].children)

# Can construct negation
@test (¬z32)[1].children == [z32[1]]
@test isequal((¬z32)[1].children, [z32[1]])

# Can construct Implies
@test (z1 .⟹ z1)[1].children == [z1[1], z1[1]]
@test isequal((z1 .⟹ z1)[1].children, [z1[1], z1[1]])

# Can construct all() and any() statements
@test any(z1 .∨ z12) == BoolExpr(:OR, [z1[1], z12[1,1], z12[1,2]], nothing, BooleanSatisfiability.__get_hash_name(:OR, [z1 z12]))
@test all(z1 .∧ z12) == BoolExpr(:AND, [z1[1], z12[1,1], z12[1,2]], nothing, BooleanSatisfiability.__get_hash_name(:AND, [z1 z12]))
@test isequal(any(z1 .∨ z12), BoolExpr(:OR, [z1[1], z12[1,1], z12[1,2]], nothing, BooleanSatisfiability.__get_hash_name(:OR, [z1 z12])))
@test isequal(all(z1 .∧ z12), BoolExpr(:AND, [z1[1], z12[1,1], z12[1,2]], nothing, BooleanSatisfiability.__get_hash_name(:AND, [z1 z12])))

# mismatched all() and any()
@test any(z1 .∧ z12) == BoolExpr(:OR, [z1[1] z12[1,1], z1[1] z12[1,2]], nothing, BooleanSatisfiability.__get_hash_name(:OR, z1.∧ z12))
@test and(z1 .∨ z12) == BoolExpr(:AND, [z1[1] z12[1,1], z1[1] z12[1,2]], nothing, BooleanSatisfiability.__get_hash_name(:AND, z1.∨ z12))
@test isequal(any(z1 .∧ z12), BoolExpr(:OR, [z1[1] z12[1,1], z1[1] z12[1,2]], nothing, BooleanSatisfiability.__get_hash_name(:OR, z1.∧ z12)))
@test isequal(and(z1 .∨ z12), BoolExpr(:AND, [z1[1] z12[1,1], z1[1] z12[1,2]], nothing, BooleanSatisfiability.__get_hash_name(:AND, z1.∨ z12)))
end

@testset "Additional operations" begin
Expand All @@ -92,24 +92,24 @@ end
z12 = Bool(1,2, "z12")

# xor
@test all(xor.(z1, z12) .== BoolExpr[xor(z1[1], z12[1,1]) xor(z1[1], z12[1,2])])
@test all(isequal.(xor.(z1, z12), BoolExpr[xor(z1[1], z12[1,1]) xor(z1[1], z12[1,2])]))
# weird cases
@test all(xor(z1) .== z1)
@test all(isequal.(xor(z1), z1))
@test xor(true, true, z) == false
@test xor(true, false, z) == ¬z
@test all(xor.(false, z, z1) .== xor.(z, z1))
@test isequal(xor(true, false, z), ¬z)
@test all(isequal.(xor.(false, z, z1), xor.(z, z1)))
# n case
@test all(xor.(z, z1, z12) .== BoolExpr[xor(z, z1[1], z12[1,1]) xor(z, z1[1], z12[1,2])])
@test all(isequal.(xor.(z, z1, z12, BoolExpr[xor(z, z1[1], z12[1,1]) xor(z, z1[1], z12[1,2])])))

# iff
@test all(iff.(z1, z12) .== BoolExpr[ iff(z1[1], z12[1,1]) iff(z1[1], z12[1,2]) ])
@test all(isequal.(iff.(z1, z12), BoolExpr[ iff(z1[1], z12[1,1]) iff(z1[1], z12[1,2]) ]))

# ite (if-then-else)
@test all( ite.(z,z1, z12) .== BoolExpr[ ite(z, z1[1], z12[1,1]) ite(z, z1[1], z12[1,2]) ])
@test all(isequal.( ite.(z,z1, z12), BoolExpr[ ite(z, z1[1], z12[1,1]) ite(z, z1[1], z12[1,2]) ]))

# mixed all and any
@test all([or(z, z1[1]), and(z, true)]) == and(or(z, z1[1]), z)
@test any([and(z, z1[1]), or(z, false)]) == or(and(z, z1[1]), z)
@test all(isequal.([or(z, z1[1]), and(z, true)]), and(or(z, z1[1]), z))
@test any(isequal.([and(z, z1[1]), or(z, false)]), or(and(z, z1[1]), z))
end

@testset "Operations with 1D literals and 1D exprs" begin
Expand All @@ -122,14 +122,14 @@ end
@test implies(false, false)

# Can operate on mixed literals and BoolExprs
@test and(true, z) == z
@test isequal(and(true, z), z)
@test and(z, false) == false
@test or(true, z) == true
@test or(z, false, false) == z
@test implies(z, false) == ¬z #or(¬z, false) == ¬z
@test implies(true, z) == z
@test isequal(or(z, false, false), z)
@test isequal(implies(z, false), ¬z) #or(¬z, false) == ¬z
@test isequal(implies(true, z), z)
end

# LEFT OFF HERE

@testset "Operations with 1D literals and nxm exprs" begin
z = Bool(2,3,"z")
Expand Down
Loading

0 comments on commit 422bcb5

Please sign in to comment.