Skip to content

Commit

Permalink
Fix primal_feasibility_report with non-Float64 number types (#3913)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Jan 15, 2025
1 parent c21bedc commit b88dac4
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/feasibility_checker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,48 @@ function _add_infeasible_constraints(
return
end

function _add_infeasible_constraints(
model::GenericModel{T},
::Type{F},
::Type{S},
violated_constraints::Dict{Any,T},
point_f::Function,
atol::T,
) where {T,F<:GenericNonlinearExpr,S}
for con in all_constraints(model, F, S)
obj = constraint_object(con)
# value(::GenericNonlinearExpr) returns `Float64`. Convert it to `T` for
# the case where the model is a different number type.
fn_value = convert(T, value(point_f, obj.func))
d = _distance_to_set(fn_value, obj.set, T)
if d > atol
violated_constraints[con] = d
end
end
return
end

function _add_infeasible_constraints(
model::GenericModel{T},
::Type{F},
::Type{S},
violated_constraints::Dict{Any,T},
point_f::Function,
atol::T,
) where {T,F<:Vector{<:GenericNonlinearExpr},S}
for con in all_constraints(model, F, S)
obj = constraint_object(con)
# value(::GenericNonlinearExpr) returns `Float64`. Convert it to `T` for
# the case where the model is a different number type.
fn_value = convert(Vector{T}, value.(point_f, obj.func))
d = _distance_to_set(fn_value, obj.set, T)
if d > atol
violated_constraints[con] = d
end
end
return
end

function _add_infeasible_nonlinear_constraints(
model::GenericModel{T},
violated_constraints::Dict{Any,T},
Expand Down
18 changes: 18 additions & 0 deletions test/test_feasibility_checker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,4 +197,22 @@ function test_nonlinear_missing()
)
end

function test_nonlinear_Float32()
model = GenericModel{Float32}()
@variable(model, x >= 2.0f0)
@constraint(model, c1, x == 1.0f0)
@constraint(model, c2, x * x == 0.0f0)
@constraint(model, c3, cos(x) == 1.0f0)
@constraint(model, c4, sqrt(x) == 1.0f0)
@constraint(model, c5, [-2 + sqrt(x), 1.0f0] in Nonnegatives())
report = primal_feasibility_report(model, Dict(x => 2.0f0))
@test length(report) == 5
@test report[c1] === 2.0f0 - 1.0f0
@test report[c2] === 2.0f0^2
@test report[c3] === 1.0f0 - cos(2.0f0)
@test report[c4] === sqrt(2.0f0) - 1.0f0
@test report[c5] === 2.0f0 - sqrt(2.0f0)
return
end

end # module

0 comments on commit b88dac4

Please sign in to comment.