From b648cd36f2f05aeb8175f9151548c60bd7c6de22 Mon Sep 17 00:00:00 2001 From: Emiko Soroka Date: Wed, 28 Jun 2023 12:21:58 -0700 Subject: [PATCH] Finished adding SMT translation, parsing, and unittests. Fixed the SMT output parser in general. --- src/IntExpr.jl | 2 - src/call_solver.jl | 2 +- src/sat.jl | 48 ++++++++++------ src/smt_representation.jl | 20 ++++++- src/utilities.jl | 118 +++++++++++++++++++++++++++++++++++++- test/int_parse_tests.jl | 41 +++++++++++++ test/runtests.jl | 3 + 7 files changed, 209 insertions(+), 25 deletions(-) diff --git a/src/IntExpr.jl b/src/IntExpr.jl index 9ccfb7c..7550a6a 100644 --- a/src/IntExpr.jl +++ b/src/IntExpr.jl @@ -179,8 +179,6 @@ function __numeric_n_ary_op(es_mixed::Array, op::Symbol) return ReturnExpr(op, es, value, __get_hash_name(op, es)) end -#Base.sum(es_mixed::Array) = __numeric_n_ary_op(es_mixed, :ADD) -#Base.prod(es_mixed::Array) = __numeric_n_ary_op(es_mixed, :MUL) # The unsightly typing here specifies the following extensions to Base.:+ # NumericExpr + NumericExpr diff --git a/src/call_solver.jl b/src/call_solver.jl index 92ce035..bd10973 100644 --- a/src/call_solver.jl +++ b/src/call_solver.jl @@ -36,7 +36,7 @@ function talk_to_solver(input::String, cmd) write(pstdin, "(get-model)\n") sleep(0.001) # IDK WHY WE NEED THIS BUT IF WE DON'T HAVE IT, pstdout HAS 0 BYTES BUFFERED output = String(readavailable(pstdout)) - satisfying_assignment = __parse_smt_output(output) + satisfying_assignment = parse_smt_output(output) return :SAT, satisfying_assignment, proc else diff --git a/src/sat.jl b/src/sat.jl index 6e69928..2551269 100644 --- a/src/sat.jl +++ b/src/sat.jl @@ -35,36 +35,50 @@ sat!(zs::Vararg{Union{Array{T}, T}}) where T <: BoolExpr = length(zs) > 0 ? sat!(zs::Array) = sat!(zs...) 
-function __assign!(z::T, values::Dict{String, Bool}) where T <: BoolExpr
+##### ASSIGNMENTS ####
+# Lookup table: operator symbol => function of the children's already-computed values. See discussion on why this is a table and not dispatch on values
+# https://docs.julialang.org/en/v1/manual/performance-tips/#The-dangers-of-abusing-multiple-dispatch-(aka,-more-on-types-with-values-as-parameters)
+# https://groups.google.com/forum/#!msg/julia-users/jUMu9A3QKQQ/qjgVWr7vAwAJ
+__reductions = Dict(
+    :NOT => (values) -> !(values[1]),
+    :AND => (values) -> reduce(&, values),
+    :OR => (values) -> reduce(|, values),
+    :XOR => (values) -> reduce(xor, values),
+    :IMPLIES => (values) -> !(values[1]) | values[2], # a => b is (not a) or b
+    :IFF => (values) -> values[1] == values[2],
+    :ITE => (values) -> (values[1] & values[2]) | (!(values[1]) & values[3]), # if c then a else b
+    :EQ => (values) -> values[1] == values[2],
+    :LT => (values) -> values[1] < values[2],
+    :LEQ => (values) -> values[1] <= values[2],
+    :GT => (values) -> values[1] > values[2],
+    :GEQ => (values) -> values[1] >= values[2],
+    :ADD => (values) -> sum(values),
+    :SUB => (values) -> values[1] - sum(values[2:end]), # first operand minus all the others
+    :MUL => (values) -> prod(values),
+    :DIV => (values) -> values[1] / prod(values[2:end]) # first operand divided by all the others
+)
+
+function __assign!(z::T, values::Dict) where T <: AbstractExpr # recursively set z.value from a Dict of name => value
     if z.op == :IDENTITY
         if z.name ∈ keys(values)
             z.value = values[z.name]
         else
             z.value = missing # this is better than nothing because & and | automatically skip it (three-valued logic).
         end
+    elseif z.op == :CONST
+        ; # CONST already has .value set so do nothing
     else
         map( (z) -> __assign!(z, values), z.children)
-        if z.op == :NOT
-            z.value = !(z.children[1].value)
-        elseif z.op == :AND
-            z.value = reduce(&, map((c) -> c.value, z.children))
-        elseif z.op == :OR
-            z.value = reduce(|, map((c) -> c.value, z.children))
-        elseif z.op == :XOR
-            z.value = reduce(xor, map((c) -> c.value, z.children))
-        elseif z.op == :IMPLIES
-            z.value = !(z.children[1].value) | z.children[2].value
-        elseif z.op == :IFF
-            z.value = z.children[1].value == z.children[2].value
-        elseif z.op == :ITE
-            z.value = (z.children[1].value & z.children[2].value) | (!(z.children[1].value) & z.children[3].value)
+        child_values = map( (c) -> c.value, z.children) # separate name: do not rebind the `values` Dict argument
+        if haskey(__reductions, z.op)
+            z.value = __reductions[z.op](child_values)
         else
-            error("Unrecognized operator $(z.op)")
+            @error("Unknown operator $(z.op)")
         end
     end
 end
 
-function __clear_assignment!(z::BoolExpr)
+function __clear_assignment!(z::AbstractExpr)
     z.value = nothing
     if length(z.children) > 0
         map(__clear_assignment!, z.children)
diff --git a/src/smt_representation.jl b/src/smt_representation.jl
index 3b29003..7234ea5 100644
--- a/src/smt_representation.jl
+++ b/src/smt_representation.jl
@@ -85,6 +85,21 @@ function __return_type(op::Symbol, zs::Array{T}) where T <: AbstractExpr
     end
 end
 
+# Return either z.name or the correct (as z.name Type) if z.name is defined for multiple types
+# This multiple name misbehavior is allowed in SMT2; the expression (as z.name Type) is called a fully qualified name.
+# It would arise if someone wrote something like xb = Bool("x"); xi = Int("x")
+function __get_smt_name(z::AbstractExpr)
+    if z.op == :CONST
+        return string(z.value)
+    end
+    global GLOBAL_VARNAMES
+    appears_in = map( (t) -> z.name ∈ GLOBAL_VARNAMES[t], __EXPR_TYPES)
+    if sum(appears_in) > 1
+        return "(as $(z.name) $(__smt_typenames[typeof(z)]))"
+    else # easy case, one variable with z.name is defined
+        return z.name
+    end
+end
 "__define_n_op! is a helper function for defining the SMT statements for n-ary ops where n >= 2.
 cache is a Dict where each value is an SMT statement and its key is the hash of the statement.
 This allows us to avoid two things:
@@ -101,7 +101,7 @@ function __define_n_op!(zs::Array{T}, op::Symbol, cache::Dict{UInt64, String}, d
     fname = __get_hash_name(op, zs)
     # if the expr is a :CONST it will have a value (e.g. 2 or 1.5), otherwise use its name
     # This yields a list like String["z_1", "z_2", "1"].
-    varnames = map( (c) -> c.op != :CONST ? c.name : string(c.value), zs)
+    varnames = map(__get_smt_name, zs)
     outname = __return_type(op, zs)
     declaration = "(define-fun $fname () $outname ($(__smt_n_opnames[op]) $(join(sort(varnames), " "))))\n"
 
@@ -122,11 +137,12 @@ function __define_n_op!(zs::Array{T}, op::Symbol, cache::Dict{UInt64, String}, d
     end
 end
+
+
 function __define_1_op!(z::AbstractExpr, op::Symbol, cache::Dict{UInt64, String}, depth::Int)
     fname = __get_hash_name(op, z.children)
     outname = __return_type(op, [z])
     prop = ""
-    declaration = "(define-fun $fname () $outname ($(__smt_1_opnames[op]) $(z.children[1].name)))\n"
+    declaration = "(define-fun $fname () $outname ($(__smt_1_opnames[op]) $(__get_smt_name(z.children[1]))))\n" # name the single child; closes both (op ...) and (define-fun ...)
     cache_key = hash(declaration)
 
     if depth == 0 && !isa(z, BoolEx)
diff --git a/src/utilities.jl b/src/utilities.jl
index 4f78292..3064ff2 100644
--- a/src/utilities.jl
+++ b/src/utilities.jl
@@ -5,9 +5,9 @@ flatten(a::Array{T}) where T = reshape(a, length(a))
 "Flatten nested arrays to a single expression using operator to
combine them. For example, [z1, [z2, z3], z4] with operator and returns and(z1, and(z2, z3), z4). This is a helper function designed to be called by save! or sat!"
-function __flatten_nested_exprs(operator, zs::Vararg{Union{Array{T}, T}}) where T <: AbstractExpr
+function __flatten_nested_exprs(operator::Function, zs::Vararg{Union{Array{T}, T}}) where T <: AbstractExpr
     # Combine the array exprs so we don't have arrays in arrays
-    zs = map( (z) -> typeof(z) == BoolExpr ? z : operator(z), zs)
+    zs = map( (z) -> isa(z, AbstractExpr) ? z : operator(z), zs)
     return and(collect(zs)) # collect turns it from a tuple to an array
 end
 
@@ -47,7 +47,7 @@ end
 
 
 ##### PARSING SMT OUTPUT #####
-
+#=
 "Utility function for parsing SMT output. Split lines based on parentheses"
 function __split_line(output, ptr)
     stack = 0
@@ -105,4 +105,116 @@ function __parse_smt_output(output::String)
     end
     # line n is the closing )
     return values
+end
+=#
+
+##### NEW OUTPUT PARSER #####
+
+# Given a string consisting of a set of statements (statement-1) \n(statement-2) etc, split into an array of strings, stripping \n and ().
+# Split one level only, so "(a(b))(c)(d)" returns ["a(b)", "c", "d"]
+# A mismatched left parenthesis like "(a)(bb" generates a warning and the output ["a", "b"]
+# A mismatched right parenthesis like "(a)b)" generates no warning and the output ["a"]
+function __split_statements(input::String)
+    statements = String[]
+    ptr = findfirst('(', input)
+    if isnothing(ptr)
+        @error "Unable to split string\n\"$input\""
+        return [input]
+    end
+    # if we get here we found a (
+    while !isnothing(ptr)
+        stack = 1 # stack tracks how many levels of () there are
+        start = ptr
+        while stack > 0
+            l = findnext('(', input, ptr+1)
+            r = findnext(')', input, ptr+1)
+            l = isnothing(l) ? length(input) : l
+            if isnothing(r)
+                @warn "( at character $ptr without matching )"
+                r = length(input)
+            end
+
+            # if we found a left parenthesis, add one level and if it's right, subtract one level
+            if l < r
+                stack += 1
+                ptr = l
+            else
+                stack -= 1
+                ptr = r
+            end
+        end
+
+        push!(statements, input[start+1:ptr-1]) # +1 and -1 strips the ( and )
+        ptr = findnext('(', input, ptr+1) # will be nothing if no more (
+    end
+    return statements
+end
+
+# Given a line like "define-fun X () Bool|Int|Real (op x1 x2 ...)"
+# skip it
+# Given a line like "define-fun X () Bool|Int|Real value|(- value)"
+# where value is true|false|int|float, return the name X and the value
+function __parse_line(line::String)
+    original_line = line # kept for error messages; Strings are immutable so no copy is needed
+    # filter ' ' and '\n'
+    line = filter((c) -> c != ' ' && c != '\n', line)
+    ptr = 10 # line always starts with define-fun so we can skip that
+    name = line[ptr+1:findnext('(', line, ptr+1)-1]
+    ptr += length(name)
+    ptr = findnext(')', line, ptr+1) # skip the next part ()
+    # figure out what the return type is
+    return_type = nothing
+    if startswith(line[ptr+1:end], "Bool")
+        return_type = Bool
+        ptr += 4
+    elseif startswith(line[ptr+1:end], "Int")
+        return_type = Int64
+        ptr += 3
+    elseif startswith(line[ptr+1:end], "Real")
+        return_type = Float64
+        ptr += 4
+    else
+        @error "Unable to parse return type of \"$original_line\""
+    end
+    value = try
+        __parse_value(return_type, line[ptr+1:end])
+    catch
+        @error "Unable to parse value of type $return_type in \"$original_line\""
+    end # @error evaluates to nothing, so a failed parse gives value = nothing and the caller can still destructure (name, value)
+    return name, value # value may be nothing if it's a function and not a variable
+end
+
+# Determine whether line represents the value of a variable (ex: "0", "true", "(- 2)")
+# or a constructed function (ex: "(+ 1 a)", "(+ 2 a b)", "(>= (+ 1 a) b)")
+# Return nothing if it's a function and the value if it's a variable
+function __parse_value(value_type::Type, line::String)
+    l = findfirst('(', line)
+    if !isnothing(l) # there is a parenthesis
+        # the only valid thing to see here is -
+        if line[l+1] != '-'
+            # now we know it's a function and not a variable
+            return nothing
+        end
+        # trim the ()
+        line = line[l+1:findlast(')', line)-1]
+    end
+    return parse(value_type, line)
+end
+
+function parse_smt_output(output::String)
+    assignments = Dict()
+    # recall the whole output will be surrounded by ()
+    output = __split_statements(output)
+    if length(output) > 1 # something is wrong!
+        @error "Unable to parse output\n\"$output\""
+        return assignments
+    end
+    # now we've cleared the outer (), so iterating will go over each line in the model
+    for line in __split_statements(output[1])
+        (name, value) = __parse_line(line)
+        if !isnothing(value)
+            assignments[name] = value
+        end
+    end
+    return assignments
 end
\ No newline at end of file
diff --git a/test/int_parse_tests.jl b/test/int_parse_tests.jl
index cd5fd59..d16bdeb 100644
--- a/test/int_parse_tests.jl
+++ b/test/int_parse_tests.jl
@@ -25,4 +25,45 @@ status = sat!(expr)
     @test value(a) == 0
     @test value(b) == -2
 
+end
+
+@testset "Parse some z3 output with ints and floats" begin
+    output = "(
+  (define-fun GEQ_d3e5e06dff9812ca () Bool
+    (>= (+ 1 b) b))
+  (define-fun ADD_f0a93f0b97da1ab2 () Int
+    (+ 1 b))
+  (define-fun LEQ_d476c845a7be63a () Bool
+    (<= (+ 2 a b) a))
+  (define-fun AND_20084a5e2cc43534 () Bool
+    (and (>= (+ 1 b) b) (<= (+ 2 a b) a)))
+  (define-fun b () Int
+    (- 2))
+  (define-fun ADD_99dce5c325207b7 () Int
+    (+ 2 a b))
+  (define-fun a () Int
+    0)
+)"
+    result = BooleanSatisfiability.parse_smt_output(output)
+    @test result == Dict("b" => -2, "a" => 0)
+
+    output = "((define-fun b () Real (- 2.5))
+(define-fun ADD_99dce5c325207b7 () Real
+(+ 2 a b))
+(define-fun a () Real
+0.0)
+))"
+    result = BooleanSatisfiability.parse_smt_output(output)
+    @test result == Dict("b" => -2.5, "a" => 0.0)
+end
+
+# Who would do this?? But it's supported anyway.
+@testset "Define fully-qualified names" begin
+    a = Int("a")
+    ar = Real("a")
+    hashname = BooleanSatisfiability.__get_hash_name(:ADD, [a, ar])
+    @test smt(a + ar) == "(declare-const a Int)
+(declare-const a Real)
+(define-fun $hashname () Real (+ (as a Int) (as a Real)))
+"
 end
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index cc3b795..b78d86f 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -13,6 +13,9 @@ include("smt_representation_tests.jl")
 # Calling Z3 and interpreting the result
 include("solver_interface_tests.jl")
 
+# Test with int and real problems
+include("int_parse_tests.jl")
+
 # Extra: Check that defining duplicate variables yields a warning
 @testset "Duplicate variable warning" begin
     SET_DUPLICATE_NAME_WARNING!(true)