Skip to content

Commit

Permalink
Merge pull request #2228 from JuliaOpt/bl/constraint_head
Browse files Browse the repository at this point in the history
Add parse_constraint_expr and parse_constraint_head
  • Loading branch information
blegat authored May 4, 2020
2 parents a53bd0b + a6f7948 commit a72e7ac
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/indicator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function parse_one_operator_constraint(
_error("Invalid right-hand side `$(rhs)` of indicator constraint. Expected constraint surrounded by `{` and `}`.")
end
rhs_con = rhs.args[1]
rhs_vectorized, rhs_parsecode, rhs_buildcall = parse_constraint(_error, rhs_con.args...)
rhs_vectorized, rhs_parsecode, rhs_buildcall = parse_constraint_expr(_error, rhs_con)
if vectorized != rhs_vectorized
_error("Inconsistent use of `.` in symbols to indicate vectorization.")
end
Expand Down
33 changes: 27 additions & 6 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,13 @@ function parse_one_operator_constraint(_error::Function, vectorized::Bool, sense
return parse_code, _build_call(_error, vectorized, :(_functionize($variable)), set)
end

function parse_constraint_expr(_error::Function, expr::Expr)
return parse_constraint_head(_error, Val(expr.head), expr.args...)
end
function parse_constraint_head(_error::Function, ::Val{:call}, args...)
return parse_constraint(_error, args...)
end

function parse_constraint(_error::Function, sense::Symbol, lhs, rhs)
(sense, vectorized) = _check_vectorized(sense)
vectorized, parse_one_operator_constraint(_error, vectorized, Val(sense), lhs, rhs)...
Expand All @@ -202,6 +209,10 @@ function parse_ternary_constraint(_error::Function, args...)
_error("Only two-sided rows of the form lb <= expr <= ub or ub >= expr >= lb are supported.")
end

function parse_constraint_head(_error::Function, ::Val{:comparison}, lb, lsign::Symbol, aff, rsign::Symbol, ub)
return parse_constraint(_error, lb, lsign, aff, rsign, ub)
end

function parse_constraint(_error::Function, lb, lsign::Symbol, aff, rsign::Symbol, ub)
(lsign, lvectorized) = _check_vectorized(lsign)
(rsign, rvectorized) = _check_vectorized(rsign)
Expand All @@ -215,13 +226,20 @@ function parse_constraint(_error::Function, lb, lsign::Symbol, aff, rsign::Symbo
vectorized, parsecode, buildcall
end

function parse_constraint(_error::Function, args...)
function _unknown_constraint_expr(_error::Function)
# Unknown
_error("Constraints must be in one of the following forms:\n" *
" expr1 <= expr2\n" * " expr1 >= expr2\n" *
" expr1 == expr2\n" * " lb <= expr <= ub")
end

function parse_constraint_head(_error::Function, ::Val, args...)
_unknown_constraint_expr(_error)
end
function parse_constraint(_error::Function, args...)
_unknown_constraint_expr(_error)
end

# Generic fallback.
function build_constraint(_error::Function, func,
set::Union{MOI.AbstractScalarSet, MOI.AbstractVectorSet})
Expand Down Expand Up @@ -374,7 +392,7 @@ function _constraint_macro(args, macro_name::Symbol, parsefun::Function)
# in a function returning `ConstraintRef`s and give it to `Containers.container`.
idxvars, indices = Containers._build_ref_sets(_error, c)

vectorized, parsecode, buildcall = parsefun(_error, x.args...)
vectorized, parsecode, buildcall = parsefun(_error, x)
_add_kw_args(buildcall, kw_args)
if vectorized
# TODO: Pass through names here.
Expand Down Expand Up @@ -457,9 +475,12 @@ that either `func` or `set` will be some custom type, rather than e.g. a
set appearing in the constraint.
"""
macro constraint(args...)
_constraint_macro(args, :constraint, parse_constraint)
_constraint_macro(args, :constraint, parse_constraint_expr)
end

function parse_SD_constraint_expr(_error::Function, expr::Expr)
return parse_SD_constraint(_error, expr.args...)
end
function parse_SD_constraint(_error::Function, sense::Symbol, lhs, rhs)
# Simple comparison - move everything to the LHS
aff = :()
Expand Down Expand Up @@ -554,7 +575,7 @@ part of the matrix assuming that it is symmetric, see [`PSDCone`](@ref) to see
how to use it.
"""
macro SDconstraint(args...)
_constraint_macro(args, :SDconstraint, parse_SD_constraint)
_constraint_macro(args, :SDconstraint, parse_SD_constraint_expr)
end

"""
Expand Down Expand Up @@ -585,8 +606,8 @@ macro build_constraint(constraint_expr)
"Are you missing a comparison (<=, >=, or ==)?")
end

is_vectorized, parse_code, build_call = parse_constraint(
_error, constraint_expr.args...)
is_vectorized, parse_code, build_call = parse_constraint_expr(
_error, constraint_expr)
result_variable = gensym()
code = quote
$parse_code
Expand Down
23 changes: 23 additions & 0 deletions test/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,27 @@ function build_constraint_keyword_test(ModelType::Type{<:JuMP.AbstractModel})
end
end

struct CustomType
end
function JuMP.parse_constraint_head(_error::Function, ::Val{:(:=)}, lhs, rhs)
return false, :(), :(build_constraint($_error, $(esc(lhs)), $(esc(rhs))))
end
struct CustomSet <: MOI.AbstractScalarSet
end
function JuMP.build_constraint(_error::Function, func, ::CustomType)
JuMP.build_constraint(_error, func, CustomSet())
end
function custom_expression_test(ModelType::Type{<:JuMP.AbstractModel})
@testset "Custom expression" begin
model = ModelType()
@variable(model, x)
@constraint(model, con_ref, x := CustomType())
con = JuMP.constraint_object(con_ref)
@test jump_function(con) == x
@test moi_set(con) isa CustomSet
end
end

function macros_test(ModelType::Type{<:JuMP.AbstractModel}, VariableRefType::Type{<:JuMP.AbstractVariableRef})
@testset "build_constraint on variable" begin
m = ModelType()
Expand Down Expand Up @@ -336,6 +357,8 @@ function macros_test(ModelType::Type{<:JuMP.AbstractModel}, VariableRefType::Typ
end

build_constraint_keyword_test(ModelType)

custom_expression_test(ModelType)
end

@testset "Macros for JuMP.Model" begin
Expand Down

0 comments on commit a72e7ac

Please sign in to comment.