From 4893734fbbbacd219452e7e92d010648d1b2f527 Mon Sep 17 00:00:00 2001 From: Jose Daniel Lara Date: Thu, 7 Mar 2024 11:41:29 -0700 Subject: [PATCH] shared axes code --- src/problems/multi_region_problem.jl | 37 ++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/src/problems/multi_region_problem.jl b/src/problems/multi_region_problem.jl index 8cff855..bb9dfaa 100644 --- a/src/problems/multi_region_problem.jl +++ b/src/problems/multi_region_problem.jl @@ -32,7 +32,22 @@ function PSI.DecisionModel{MultiRegionProblem}( ) end -function _get_axes!(common_axes::Dict, container::PSI.OptimizationContainer) +function _join_axes!(axes_data::SortedDict{Int, Set}, ix::Int, axes_value::UnitRange{Int}) + _axes_data = get!(axes_data, ix, Set{UnitRange{Int}}()) + if _axes_data == axes_value + return + end + union!(_axes_data, [axes_value]) + return +end + +function _join_axes!(axes_data::SortedDict{Int, Set}, ix::Int, axes_value::Vector) + _axes_data = get!(axes_data, ix, Set{eltype(axes_value)}()) + union!(_axes_data, axes_value) + return +end + +function _get_axes!(common_axes::Dict{Symbol, Dict{PSI.OptimizationContainerKey, SortedDict{Int, Set}}}, container::PSI.OptimizationContainer) for field in CONTAINER_FIELDS field_data = getfield(container, field) for (key, value_container) in field_data @@ -41,30 +56,38 @@ function _get_axes!(common_axes::Dict, container::PSI.OptimizationContainer) end axes_data = get!(common_axes[field], key, SortedDict{Int, Set}()) for (ix, vals) in enumerate(axes(value_container)) - union!(get!(axes_data, ix, Set(vals))) + _join_axes!(axes_data, ix, vals) end end end return end +function _make_joint_axes!(dim1::Set{T}, dim2::Set{UnitRange{Int}}) where T <: Union{Int, String} + return (collect(dim1), first(dim2)) +end + +function _make_joint_axes!(dim1::Set{UnitRange{Int}}) + return (first(dim1),) +end + + function _map_containers(model::PSI.DecisionModel{MultiRegionProblem}) common_axes = - Dict(key => Dict{PSI.OptimizationContainerKey, Any}() for key in CONTAINER_FIELDS) + Dict{Symbol, Dict{PSI.OptimizationContainerKey, SortedDict{Int, Set}}}(key => Dict{PSI.OptimizationContainerKey, SortedDict{Int, Set}}() for key in CONTAINER_FIELDS) container = PSI.get_optimization_container(model) - for subproblem in values(container.subproblems) + for (_, subproblem) in container.subproblems _get_axes!(common_axes, subproblem) end for (field, vals) in common_axes - field_data = getfield(container, field) + field_data = getproperty(container, field) for (key, axes_data) in vals - ax = [sort!(collect(v)) for v in values(axes_data)] + ax = _make_joint_axes!(collect(values(axes_data))...) field_data[key] = PSI.remove_undef!(JuMP.Containers.DenseAxisArray{Float64}(undef, ax...)) end end - #TODO: Parameters Requires a different approach return