Skip to content

Commit

Permalink
shared axes code
Browse files Browse the repository at this point in the history
  • Loading branch information
jd-lara committed Mar 7, 2024
1 parent ebcb423 commit 4893734
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions src/problems/multi_region_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,22 @@ function PSI.DecisionModel{MultiRegionProblem}(
)
end

function _get_axes!(common_axes::Dict, container::PSI.OptimizationContainer)
function _join_axes!(axes_data::SortedDict{Int, Set}, ix::Int, axes_value::UnitRange{Int})
_axes_data = get!(axes_data, ix, Set{UnitRange{Int}}())
if _axes_data == axes_value
return
end
union!(_axes_data, [axes_value])
return
end

function _join_axes!(axes_data::SortedDict{Int, Set}, ix::Int, axes_value::Vector)
_axes_data = get!(axes_data, ix, Set{eltype(axes_value)}())
union!(_axes_data, axes_value)
return
end

function _get_axes!(common_axes::Dict{Symbol, Dict{PSI.OptimizationContainerKey, SortedDict{Int, Set}}}, container::PSI.OptimizationContainer)
for field in CONTAINER_FIELDS
field_data = getfield(container, field)
for (key, value_container) in field_data
Expand All @@ -41,30 +56,38 @@ function _get_axes!(common_axes::Dict, container::PSI.OptimizationContainer)
end
axes_data = get!(common_axes[field], key, SortedDict{Int, Set}())
for (ix, vals) in enumerate(axes(value_container))
union!(get!(axes_data, ix, Set(vals)))
_join_axes!(axes_data, ix, vals)
end
end
end
return
end

function _make_joint_axes!(dim1::Set{T}, dim2::Set{UnitRange{Int}}) where T <: Union{Int, String}
return (collect(dim1), first(dim2))
end

function _make_joint_axes!(dim1::Set{UnitRange{Int}})
return (first(dim1),)
end


function _map_containers(model::PSI.DecisionModel{MultiRegionProblem})
common_axes =
Dict(key => Dict{PSI.OptimizationContainerKey, Any}() for key in CONTAINER_FIELDS)
Dict{Symbol, Dict{PSI.OptimizationContainerKey, SortedDict{Int, Set}}}(key => Dict{PSI.OptimizationContainerKey, SortedDict{Int, Set}}() for key in CONTAINER_FIELDS)
container = PSI.get_optimization_container(model)
for subproblem in values(container.subproblems)
for (_, subproblem) in container.subproblems
_get_axes!(common_axes, subproblem)
end

for (field, vals) in common_axes
field_data = getfield(container, field)
field_data = getproperty(container, field)
for (key, axes_data) in vals
ax = [sort!(collect(v)) for v in values(axes_data)]
ax = _make_joint_axes!(collect(values(axes_data))...)
field_data[key] =
PSI.remove_undef!(JuMP.Containers.DenseAxisArray{Float64}(undef, ax...))
end
end

#TODO: Parameters Requires a different approach

return
Expand Down

0 comments on commit 4893734

Please sign in to comment.