From f72ffeea25f7cfded284d9dc5eb1a3440b76ab15 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Tue, 17 Dec 2024 13:05:53 -0700 Subject: [PATCH] fix cfg_vanilla --- mpisppy/utils/cfg_vanilla.py | 41 ++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/mpisppy/utils/cfg_vanilla.py b/mpisppy/utils/cfg_vanilla.py index 4c1fe84b..e7fa6376 100644 --- a/mpisppy/utils/cfg_vanilla.py +++ b/mpisppy/utils/cfg_vanilla.py @@ -389,6 +389,7 @@ def _PHBase_spoke_foundation( rho_setter=None, all_nodenames=None, ph_extensions=None, + extension_kwargs=None, ): # only the shared options shoptions = shared_options(cfg) @@ -410,6 +411,8 @@ def _PHBase_spoke_foundation( spoke_dict["opt_kwargs"]["rho_setter"] = rho_setter if ph_extensions is not None: spoke_dict["opt_kwargs"]["extensions"] = ph_extensions + if extension_kwargs is not None: + spoke_dict["opt_kwargs"]["extension_kwargs"] = extension_kwargs return spoke_dict @@ -423,6 +426,7 @@ def _Xhat_Eval_spoke_foundation( rho_setter=None, all_nodenames=None, ph_extensions=None, + extension_kwargs=None, ): spoke_dict = _PHBase_spoke_foundation( spoke_class, @@ -433,11 +437,10 @@ def _Xhat_Eval_spoke_foundation( scenario_creator_kwargs=scenario_creator_kwargs, rho_setter=rho_setter, all_nodenames=all_nodenames, - ph_extensions=ph_extensions) + ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, + ) spoke_dict["opt_class"] = Xhat_Eval - if ph_extensions is not None: - spoke_dict["opt_kwargs"]["ph_extensions"] = ph_extensions - del spoke_dict["opt_kwargs"]["extensions"] # ph_extensions in Xhat_Eval return spoke_dict @@ -449,6 +452,8 @@ def lagrangian_spoke( scenario_creator_kwargs=None, rho_setter=None, all_nodenames=None, + ph_extensions=None, + extension_kwargs=None, ): lagrangian_spoke = _PHBase_spoke_foundation( LagrangianOuterBound, @@ -459,6 +464,8 @@ def lagrangian_spoke( scenario_creator_kwargs=scenario_creator_kwargs, rho_setter=rho_setter, all_nodenames=all_nodenames, + ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) if cfg.lagrangian_iter0_mipgap is not None: lagrangian_spoke["opt_kwargs"]["options"]["iter0_solver_options"]\ @@ -479,6 +486,8 @@ def reduced_costs_spoke( scenario_creator_kwargs=None, rho_setter=None, all_nodenames=None, + ph_extensions=None, + extension_kwargs=None, ): rc_spoke = _PHBase_spoke_foundation( ReducedCostsSpoke, @@ -489,6 +498,8 @@ def reduced_costs_spoke( scenario_creator_kwargs=scenario_creator_kwargs, rho_setter=rho_setter, all_nodenames=all_nodenames, + ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) add_ph_tracking(rc_spoke, cfg, spoke=True) @@ -506,6 +517,8 @@ def lagranger_spoke( scenario_creator_kwargs=None, rho_setter=None, all_nodenames = None, + ph_extensions=None, + extension_kwargs=None, ): lagranger_spoke = _PHBase_spoke_foundation( LagrangerOuterBound, @@ -516,6 +529,8 @@ def lagranger_spoke( scenario_creator_kwargs=scenario_creator_kwargs, rho_setter=rho_setter, all_nodenames=all_nodenames, + ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) if cfg.lagranger_iter0_mipgap is not None: lagranger_spoke["opt_kwargs"]["options"]["iter0_solver_options"]\ @@ -539,6 +554,8 @@ def subgradient_spoke( scenario_creator_kwargs=None, rho_setter=None, all_nodenames=None, + ph_extensions=None, + extension_kwargs=None, ): subgradient_spoke = _PHBase_spoke_foundation( SubgradientOuterBound, @@ -549,6 +566,8 @@ def subgradient_spoke( scenario_creator_kwargs=scenario_creator_kwargs, rho_setter=rho_setter, all_nodenames=all_nodenames, + ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) if cfg.subgradient_iter0_mipgap is not None: subgradient_spoke["opt_kwargs"]["options"]["iter0_solver_options"]\ @@ -571,6 +590,7 @@ def xhatlooper_spoke( all_scenario_names, scenario_creator_kwargs=None, ph_extensions=None, + extension_kwargs=None, ): xhatlooper_dict = _Xhat_Eval_spoke_foundation( @@ -580,7 +600,8 @@ def xhatlooper_spoke( scenario_denouement, all_scenario_names, scenario_creator_kwargs=scenario_creator_kwargs, - ph_extensions=ph_extensions, + extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) xhatlooper_dict["opt_kwargs"]["options"]['bundles_per_rank'] = 0 # no bundles for xhat @@ -602,6 +623,7 @@ def xhatxbar_spoke( scenario_creator_kwargs=None, variable_probability=None, ph_extensions=None, + extension_kwargs=None, all_nodenames=None, ): xhatxbar_dict = _Xhat_Eval_spoke_foundation( @@ -612,6 +634,7 @@ def xhatxbar_spoke( all_scenario_names, scenario_creator_kwargs=scenario_creator_kwargs, ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, all_nodenames=all_nodenames, ) @@ -635,6 +658,7 @@ def xhatshuffle_spoke( all_nodenames=None, scenario_creator_kwargs=None, ph_extensions=None, + extension_kwargs=None, ): xhatshuffle_dict = _Xhat_Eval_spoke_foundation( @@ -646,6 +670,7 @@ def xhatshuffle_spoke( all_nodenames=all_nodenames, scenario_creator_kwargs=scenario_creator_kwargs, ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) xhatshuffle_dict["opt_kwargs"]["options"]['bundles_per_rank'] = 0 # no bundles for xhat xhatshuffle_dict["opt_kwargs"]["options"]["xhat_looper_options"] = { @@ -669,7 +694,8 @@ def xhatspecific_spoke( scenario_dict, all_nodenames=None, scenario_creator_kwargs=None, - ph_extensions=None, + ph_extensions=None, + extension_kwargs=None, ): xhatspecific_dict = _Xhat_Eval_spoke_foundation( @@ -680,6 +706,7 @@ def xhatspecific_spoke( all_scenario_names, scenario_creator_kwargs=scenario_creator_kwargs, ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) xhatspecific_dict["opt_kwargs"]["options"]['bundles_per_rank'] = 0 # no bundles for xhat return xhatspecific_dict @@ -691,6 +718,7 @@ def xhatlshaped_spoke( all_scenario_names, scenario_creator_kwargs=None, ph_extensions=None, + extension_kwargs=None, ): xhatlshaped_dict = _Xhat_Eval_spoke_foundation( @@ -701,6 +729,7 @@ def xhatlshaped_spoke( all_scenario_names, scenario_creator_kwargs=scenario_creator_kwargs, ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) xhatlshaped_dict["opt_kwargs"]["options"]['bundles_per_rank'] = 0 # no bundles for xhat