Skip to content

Commit

Permalink
Merge pull request #3626 from JuliaReach/schillic/SymEngine
Browse files Browse the repository at this point in the history
Share common `SymEngine` code
  • Loading branch information
schillic authored Jul 27, 2024
2 parents a5a4046 + 5a085d4 commit 623a947
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 81 deletions.
15 changes: 15 additions & 0 deletions src/Initialization/init_SymEngine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ julia> free_symbols(:(x1 + x2 <= 2*x4 + 6), HalfSpace)
"""
function free_symbols(::Expr, ::Type{<:LazySet}) end # COV_EXCL_LINE

# parse `a` and `b` from `a1 x1 + ... + an xn + K [cmp] 0` for [cmp] in {<, <=, =, >, >=}
function _parse_linear_expression(linexpr::Basic, vars::Vector{Basic}, N)
if isempty(vars)
vars = SymEngine.free_symbols(linexpr)
end
b = SymEngine.subs(linexpr, [vi => zero(N) for vi in vars]...)
a = convert(Basic, linexpr - b)

# convert to correct numeric type
a = convert(Vector{N}, diff.(a, vars))
b = convert(N, b)

return a, b
end

# Note: this convenience function is not used anywhere
function _free_symbols(expr::Expr)
if _is_hyperplane(expr)
Expand Down
9 changes: 4 additions & 5 deletions src/Interfaces/AbstractPolyhedron.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,10 @@ function _linear_map_hrep_helper(M::AbstractMatrix, P::LazySet,
return HPolyhedron(constraints)
end

# internal function; defined here due to dependency SymEngine and submodules
function _is_halfspace() end

# internal function; defined here due to dependency SymEngine and submodules
function _is_hyperplane() end
# internal functions; defined here due to dependency SymEngine and submodules
function _is_halfspace end
function _is_hyperplane end
function _parse_linear_expression end

# To account for the compilation order, other functions are defined in the file
# AbstractPolyhedron_functions.jl
38 changes: 13 additions & 25 deletions src/Sets/HalfSpace/convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ end
function load_SymEngine_convert_HalfSpace()
return quote
using .SymEngine: Basic
using ..LazySets: _parse_linear_expression

"""
convert(::Type{HalfSpace{N}}, expr::Expr; vars=nothing) where {N}
Expand Down Expand Up @@ -53,31 +54,18 @@ function load_SymEngine_convert_HalfSpace()
```
"""
function convert(::Type{HalfSpace{N}}, expr::Expr; vars::Vector{Basic}=Basic[]) where {N}
@assert _is_halfspace(expr) "the expression :(expr) does not correspond to a half-space"

# check sense of the inequality, assuming < or <= by default
got_geq = expr.args[1] in [:(>=), :(>)]

# get sides of the inequality
lhs, rhs = convert(Basic, expr.args[2]), convert(Basic, expr.args[3])

# a1 x1 + ... + an xn + K [cmp] 0 for cmp in <, <=, >, >=
eq = lhs - rhs
if isempty(vars)
vars = SymEngine.free_symbols(eq)
end
K = SymEngine.subs(eq, [vi => zero(N) for vi in vars]...)
a = convert(Basic, eq - K)

# convert to numeric types
K = convert(N, K)
a = convert(Vector{N}, diff.(a, vars))

if got_geq
return HalfSpace(-a, K)
else
return HalfSpace(a, -K)
end
@assert _is_halfspace(expr) "the expression $expr does not correspond to a half-space"

# convert to SymEngine expressions
linexpr, cmp = _parse_halfspace(expr)

# check sense of the inequality, assuming < or <= by default (checked before)
got_geq = cmp (:(>=), :(>))

# `a1 x1 + ... + an xn + b [cmp] 0` for [cmp] ∈ {<, <=, >, >=}
a, b = _parse_linear_expression(linexpr, vars, N)

return got_geq ? HalfSpace(-a, b) : HalfSpace(a, -b)
end

# type-less default half-space conversion
Expand Down
12 changes: 9 additions & 3 deletions src/Sets/HalfSpace/init_SymEngine.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import .SymEngine: free_symbols

function _parse_halfspace(expr::Expr)
lhs = convert(SymEngine.Basic, expr.args[2])
rhs = convert(SymEngine.Basic, expr.args[3])
cmp = expr.args[1]
return (lhs - rhs, cmp)
end

function free_symbols(expr::Expr, ::Type{<:HalfSpace})
# get sides of the inequality
lhs, rhs = convert(SymEngine.Basic, expr.args[2]), convert(SymEngine.Basic, expr.args[3])
return SymEngine.free_symbols(lhs - rhs)
linexpr, _ = _parse_halfspace(expr)
return SymEngine.free_symbols(linexpr)
end

eval(load_SymEngine_ishalfspace())
Expand Down
17 changes: 7 additions & 10 deletions src/Sets/HalfSpace/ishalfspace.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
function load_SymEngine_ishalfspace()
return quote
using .SymEngine: Basic
import .SymEngine: free_symbols
using ..LazySets: _is_linearcombination

"""
Expand Down Expand Up @@ -45,23 +43,22 @@ function load_SymEngine_ishalfspace()
```
"""
function _is_halfspace(expr::Expr)::Bool

# check that there are three arguments
# these are the comparison symbol, the left hand side and the right hand side
# check that there are three arguments:
# the comparison symbol, the left-hand side and the right-hand side
if (length(expr.args) != 3) || !(expr.head == :call)
return false
end

# convert to SymEngine expression
linexpr, cmp = _parse_halfspace(expr)

# check that this is an inequality
if !(expr.args[1] in [:(<=), :(<), :(>=), :(>)])
if cmp [:(<=), :(<), :(>=), :(>)]
return false
end

# convert to symengine expressions
lhs, rhs = convert(Basic, expr.args[2]), convert(Basic, expr.args[3])

# check if the expression defines a half-space
return _is_linearcombination(lhs) && _is_linearcombination(rhs)
return _is_linearcombination(linexpr)
end
end
end # load_SymEngine_ishalfspace
Expand Down
26 changes: 7 additions & 19 deletions src/Sets/Hyperplane/convert.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
function load_SymEngine_convert_Hyperplane()
return quote
using .SymEngine: Basic
using ..LazySets: _parse_linear_expression

"""
convert(::Type{Hyperplane{N}}, expr::Expr; vars=nothing) where {N}
Expand Down Expand Up @@ -41,28 +42,15 @@ function load_SymEngine_convert_Hyperplane()
```
"""
function convert(::Type{Hyperplane{N}}, expr::Expr; vars::Vector{Basic}=Basic[]) where {N}
@assert _is_hyperplane(expr) "the expression :(expr) does not correspond to a Hyperplane"
@assert _is_hyperplane(expr) "the expression $expr does not correspond to a Hyperplane"

# get sides of the inequality
lhs = convert(Basic, expr.args[1])
# convert to SymEngine expression
linexpr = _parse_hyperplane(expr)

# treats the 4 in :(2*x1 = 4)
rhs = :args in fieldnames(typeof(expr.args[2])) ? convert(Basic, expr.args[2].args[2]) :
convert(Basic, expr.args[2])
# a1 x1 + ... + an xn + b = 0
a, b = _parse_linear_expression(linexpr, vars, N)

# a1 x1 + ... + an xn + K = 0
eq = lhs - rhs
if isempty(vars)
vars = SymEngine.free_symbols(eq)
end
K = SymEngine.subs(eq, [vi => zero(N) for vi in vars]...)
a = convert(Basic, eq - K)

# convert to numeric types
K = convert(N, K)
a = convert(Vector{N}, diff.(a, vars))

return Hyperplane(a, -K)
return Hyperplane(a, -b)
end

# type-less default Hyperplane conversion
Expand Down
12 changes: 7 additions & 5 deletions src/Sets/Hyperplane/init_SymEngine.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import .SymEngine: free_symbols

function free_symbols(expr::Expr, ::Type{<:Hyperplane})
# get sides of the equality
function _parse_hyperplane(expr::Expr)
lhs = convert(SymEngine.Basic, expr.args[1])

# treats the 4 in :(2*x1 = 4)
rhs = :args in fieldnames(typeof(expr.args[2])) ?
convert(SymEngine.Basic, expr.args[2].args[2]) :
convert(SymEngine.Basic, expr.args[2])
return SymEngine.free_symbols(lhs - rhs)
return lhs - rhs
end

function free_symbols(expr::Expr, ::Type{<:Hyperplane})
linexpr = _parse_hyperplane(expr)
return SymEngine.free_symbols(linexpr)
end

eval(load_SymEngine_ishyperplanar())
Expand Down
19 changes: 5 additions & 14 deletions src/Sets/Hyperplane/is_hyperplanar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ end

function load_SymEngine_ishyperplanar()
return quote
using .SymEngine: Basic
using ..LazySets: _is_linearcombination

"""
Expand Down Expand Up @@ -48,25 +47,17 @@ function load_SymEngine_ishyperplanar()
```
"""
function _is_hyperplane(expr::Expr)::Bool

# check that there are three arguments
# these are the comparison symbol, the left hand side and the right hand side
# check that the head is `=` and there are two arguments:
# the left-hand side and the right-hand side
if (length(expr.args) != 2) || !(expr.head == :(=))
return false
end

# convert to symengine expressions
lhs = convert(Basic, expr.args[1])

if :args in fieldnames(typeof(expr.args[2]))
# treats the 4 in :(2*x1 = 4)
rhs = convert(Basic, expr.args[2].args[2])
else
rhs = convert(Basic, expr.args[2])
end
# convert to SymEngine expression
linexpr = _parse_hyperplane(expr)

# check if the expression defines a hyperplane
return _is_linearcombination(lhs) && _is_linearcombination(rhs)
return _is_linearcombination(linexpr)
end
end
end # load_SymEngine_ishyperplanar
Expand Down

0 comments on commit 623a947

Please sign in to comment.