Skip to content

Commit

Permalink
Finished adding SMT translation, parsing, and unittests. Fixed the SMT
Browse files Browse the repository at this point in the history
output parser in general.
  • Loading branch information
elsoroka committed Jun 28, 2023
1 parent 8794223 commit b648cd3
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 25 deletions.
2 changes: 0 additions & 2 deletions src/IntExpr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,6 @@ function __numeric_n_ary_op(es_mixed::Array, op::Symbol)
return ReturnExpr(op, es, value, __get_hash_name(op, es))
end

#Base.sum(es_mixed::Array) = __numeric_n_ary_op(es_mixed, :ADD)
#Base.prod(es_mixed::Array) = __numeric_n_ary_op(es_mixed, :MUL)

# The unsightly typing here specifies the following extensions to Base.:+
# NumericExpr + NumericExpr
Expand Down
2 changes: 1 addition & 1 deletion src/call_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function talk_to_solver(input::String, cmd)
write(pstdin, "(get-model)\n")
sleep(0.001) # IDK WHY WE NEED THIS BUT IF WE DON'T HAVE IT, pstdout HAS 0 BYTES BUFFERED
output = String(readavailable(pstdout))
satisfying_assignment = __parse_smt_output(output)
satisfying_assignment = parse_smt_output(output)
return :SAT, satisfying_assignment, proc

else
Expand Down
48 changes: 31 additions & 17 deletions src/sat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,50 @@ sat!(zs::Vararg{Union{Array{T}, T}}) where T <: BoolExpr = length(zs) > 0 ?
sat!(zs::Array) = sat!(zs...)


function __assign!(z::T, values::Dict{String, Bool}) where T <: BoolExpr
##### ASSIGNMENTS ####
# see discussion on why this is the way it is
# https://docs.julialang.org/en/v1/manual/performance-tips/#The-dangers-of-abusing-multiple-dispatch-(aka,-more-on-types-with-values-as-parameters)
# https://groups.google.com/forum/#!msg/julia-users/jUMu9A3QKQQ/qjgVWr7vAwAJ
__reductions = Dict(
:NOT => (values) -> !(values[1]),
:AND => (values) -> reduce(&, values),
:OR => (values) -> reduce(|, values),
:XOR => (values) -> reduce(xor, values),
:IMPLIES => (values) -> (values[1]) | values[2],
:IFF => (values) -> values[1] == values[2],
:ITE => (values) -> (values[1] & values[2]) | (values[1] & values[3]),
:EQ => (values) -> values[1] == values[2],
:LT => (values) -> values[1] < values[2],
:LEQ => (values) -> values[1] <= values[2],
:GT => (values) -> values[1] > values[2],
:GEQ => (values) -> values[1] >= values[2],
:ADD => (values) -> sum(values),
:SUB => (values) -> value[1] - sum(values[2:end])
:MUL => (values) -> prod(values)
:DIV => (values) -> value[1] / prod(values[2:end])
)

function __assign!(z::T, values::Dict) where T <: AbstractExpr
if z.op == :IDENTITY
if z.name keys(values)
z.value = values[z.name]
else
z.value = missing # this is better than nothing because & and | automatically skip it (three-valued logic).
end
elseif z.op == :CONST
; # CONST already has .value set so do nothing
else
map( (z) -> __assign!(z, values), z.children)
if z.op == :NOT
z.value = !(z.children[1].value)
elseif z.op == :AND
z.value = reduce(&, map((c) -> c.value, z.children))
elseif z.op == :OR
z.value = reduce(|, map((c) -> c.value, z.children))
elseif z.op == :XOR
z.value = reduce(xor, map((c) -> c.value, z.children))
elseif z.op == :IMPLIES
z.value = !(z.children[1].value) | z.children[2].value
elseif z.op == :IFF
z.value = z.children[1].value == z.children[2].value
elseif z.op == :ITE
z.value = (z.children[1].value & z.children[2].value) | (!(z.children[1].value) & z.children[3].value)
values = map( (z) -> z.value, z.children)
if z.op keys(__reductions)
z.value = __reductions[z.op](values)
else
error("Unrecognized operator $(z.op)")
@error("Unknown operator $(z.op)")
end
end
end

function __clear_assignment!(z::BoolExpr)
function __clear_assignment!(z::AbstractExpr)
z.value = nothing
if length(z.children) > 0
map(__clear_assignment!, z.children)
Expand Down
20 changes: 18 additions & 2 deletions src/smt_representation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,21 @@ function __return_type(op::Symbol, zs::Array{T}) where T <: AbstractExpr
end
end

# Return either z.name or the correct (as z.name Type) if z.name is defined for multiple types
# This multiple name misbehavior is allowed in SMT2; the expression (as z.name Type) is called a fully qualified name.
# It would arise if someone wrote something like xb = Bool("x"); xi = Int("x")
function __get_smt_name(z::AbstractExpr)
if z.op == :CONST
return string(z.value)
end
global GLOBAL_VARNAMES
appears_in = map( (t) -> z.name GLOBAL_VARNAMES[t], __EXPR_TYPES)
if sum(appears_in) > 1
return "(as $(z.name) $(__smt_typenames[typeof(z)]))"
else # easy case, one variable with z.name is defined
return z.name
end
end

"__define_n_op! is a helper function for defining the SMT statements for n-ary ops where n >= 2.
cache is a Dict where each value is an SMT statement and its key is the hash of the statement. This allows us to avoid two things:
Expand All @@ -101,7 +116,7 @@ function __define_n_op!(zs::Array{T}, op::Symbol, cache::Dict{UInt64, String}, d
fname = __get_hash_name(op, zs)
# if the expr is a :CONST it will have a value (e.g. 2 or 1.5), otherwise use its name
# This yields a list like String["z_1", "z_2", "1"].
varnames = map( (c) -> c.op != :CONST ? c.name : string(c.value), zs)
varnames = map(__get_smt_name, zs)
outname = __return_type(op, zs)

declaration = "(define-fun $fname () $outname ($(__smt_n_opnames[op]) $(join(sort(varnames), " "))))\n"
Expand All @@ -122,11 +137,12 @@ function __define_n_op!(zs::Array{T}, op::Symbol, cache::Dict{UInt64, String}, d
end
end


function __define_1_op!(z::AbstractExpr, op::Symbol, cache::Dict{UInt64, String}, depth::Int)
fname = __get_hash_name(op, z.children)
outname = __return_type(op, [z])
prop = ""
declaration = "(define-fun $fname () $outname ($(__smt_1_opnames[op]) $(z.children[1].name)))\n"
declaration = "(define-fun $fname () $outname ($(__smt_1_opnames[op]) $(__get_smt_name(z.children)))\n"
cache_key = hash(declaration)

if depth == 0 && !isa(z, BoolEx)
Expand Down
118 changes: 115 additions & 3 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ flatten(a::Array{T}) where T = reshape(a, length(a))
"Flatten nested arrays to a single expression using operator to combine them.
For example, [z1, [z2, z3], z4] with operator and returns and(z1, and(z2, z3), z4).
This is a helper function designed to be called by save! or sat!"
function __flatten_nested_exprs(operator, zs::Vararg{Union{Array{T}, T}}) where T <: AbstractExpr
function __flatten_nested_exprs(operator::Function, zs::Vararg{Union{Array{T}, T}}) where T <: AbstractExpr
# Combine the array exprs so we don't have arrays in arrays
zs = map( (z) -> typeof(z) == BoolExpr ? z : operator(z), zs)
zs = map( (z) -> isa(z, AbstractExpr) ? z : operator(z), zs)
return and(collect(zs)) # collect turns it from a tuple to an array
end

Expand Down Expand Up @@ -47,7 +47,7 @@ end


##### PARSING SMT OUTPUT #####

#=
"Utility function for parsing SMT output. Split lines based on parentheses"
function __split_line(output, ptr)
stack = 0
Expand Down Expand Up @@ -105,4 +105,116 @@ function __parse_smt_output(output::String)
end
# line n is the closing )
return values
end
=#

##### NEW OUTPUT PARSER #####

# Given a string consisting of a set of statements (statement-1) \n(statement-2) etc, split into an array of strings, stripping \n and ().
# Split one level only, so "(a(b))(c)(d)" returns ["a(b)", "c", "d"]
# A mismatched left parenthesis like "(a)(bb" generates a warning and the output ["a", "b"]
# A mismatched right parenthesis like "(a)b)" generates no warning and the output ["a"]
function __split_statements(input::String)
statements = String[]
ptr = findfirst('(', input)
if isnothing(ptr)
@error "Unable to split string\n\"$input\""
return [input]
end
# if we get here we found a (
while !isnothing(ptr)
stack = 1 # stack tracks how many levels of () there are
start = ptr
while stack > 0
l = findnext('(', input, ptr+1)
r = findnext(')', input, ptr+1)
l = isnothing(l) ? length(input) : l
if isnothing(r)
@warn "( at character $ptr without matching )"
r = length(input)
end

# if we found a left parenthesis, add one level and if it's right, subtract one level
if l < r
stack += 1
ptr = l
else
stack -= 1
ptr = r
end
end

push!(statements, input[start+1:ptr-1]) # +1 and -1 strips the ( and )
ptr = findnext('(', input, ptr+1) # will be nothing if no more (
end
return statements
end

# Given a line like "define-fun X () Bool|Int|Real (op x1 x2 ...)"
# skip it
# Given a line like "define-fun X () Bool|Int|Real value|(- value)"
# where value is true|false|int|float, return the name X and the value
function __parse_line(line::String)
original_line = deepcopy(line)
# filter ' ' and '\n'
line = filter((c) -> c != ' ' && c != '\n', line)
ptr = 10 # line always starts with define-fun so we can skip that
name = line[ptr+1:findnext('(', line, ptr+1)-1]
ptr += length(name)
ptr = findnext(')', line, ptr+1) # skip the next part ()
# figure out what the return type is
return_type = nothing
if startswith(line[ptr+1:end], "Bool")
return_type = Bool
ptr += 4
elseif startswith(line[ptr+1:end], "Int")
return_type = Int64
ptr += 3
elseif startswith(line[ptr+1:end], "Real")
return_type = Float64
ptr += 4
else
@error "Unable to parse return type of \"$original_line\""
end
try
value = __parse_value(return_type, line[ptr+1:end])
return name, value # value may be nothing if it's a function and not a variable
catch
@error "Unable to parse value of type $return_type in \"$original_line\""
end
end

# Determine whether line represents the value of a variable (ex: "0", "true", "(- 2)")
# or a constructed function (ex: "(+ 1 a)", "(+ 2 a b"), "(>= (+ 1 a) b)")
# Return nothing if it's a function and the value if it's a variable
function __parse_value(value_type::Type, line::String)
l = findfirst('(', line)
if !isnothing(l) # there is a parenthesis
# the only valid thing to see here is -
if line[l+1] != '-'
# now we know it's a function and not a variable
return nothing
end
# trim the ()
line = line[l+1:findlast(')', line)-1]
end
return parse(value_type, line)
end

function parse_smt_output(output::String)
assignments = Dict()
# recall the whole output will be surrounded by ()
output = __split_statements(output)
if length(output) > 1 # something is wrong!
@error "Unable to parse output\n\"$output\""
return assignments
end
# now we've cleared the outer (), so iterating will go over each line in the model
for line in __split_statements(output[1])
(name, value) = __parse_line(line)
if !isnothing(value)
assignments[name] = value
end
end
return assignments
end
41 changes: 41 additions & 0 deletions test/int_parse_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,45 @@ status = sat!(expr)
@test value(a) == 0
@test value(b) == -2

end

@testset "Parse some z3 output with ints and floats" begin
output = "(
(define-fun GEQ_d3e5e06dff9812ca () Bool
(>= (+ 1 b) b))
(define-fun ADD_f0a93f0b97da1ab2 () Int
(+ 1 b))
(define-fun LEQ_d476c845a7be63a () Bool
(<= (+ 2 a b) a))
(define-fun AND_20084a5e2cc43534 () Bool
(and (>= (+ 1 b) b) (<= (+ 2 a b) a)))
(define-fun b () Int
(- 2))
(define-fun ADD_99dce5c325207b7 () Int
(+ 2 a b))
(define-fun a () Int
0)
)"
result = BooleanSatisfiability.parse_smt_output(output)
@test result == Dict("b" => -2, "a" => 0)

output = "((define-fun b () Real (- 2.5))
(define-fun ADD_99dce5c325207b7 () Real
(+ 2 a b))
(define-fun a () Real
0.0)
))"
result = BooleanSatisfiability.parse_smt_output(output)
@test result == Dict("b" => -2.5, "a" => 0.0)
end

# Who would do this?? But it's supported anyway.
@testset "Define fully-qualified names" begin
a = Int("a")
ar = Real("a")
hashname = BooleanSatisfiability.__get_hash_name(:ADD, [a, ar])
@test smt(a + ar) == "(declare-const a Int)
(declare-const a Real)
(define-fun $hashname () Real (+ (as a Int) (as a Real)))
"
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ include("smt_representation_tests.jl")
# Calling Z3 and interpreting the result
include("solver_interface_tests.jl")

# Test with int and real problems
include("int_parse_tests.jl")

# Extra: Check that defining duplicate variables yields a warning
@testset "Duplicate variable warning" begin
SET_DUPLICATE_NAME_WARNING!(true)
Expand Down

0 comments on commit b648cd3

Please sign in to comment.