From 8596bed5ce78bff70223bba6c7c886bc472dcafd Mon Sep 17 00:00:00 2001 From: Janos Gabler Date: Sat, 30 Nov 2024 13:51:38 +0100 Subject: [PATCH] Add more information to InternalOptimizationProblem. --- .../internal_optimization_problem.py | 11 ++++ src/optimagic/optimization/optimize.py | 54 ++++++++++--------- src/optimagic/optimization/process_results.py | 22 ++------ src/optimagic/typing.py | 11 ++++ .../test_internal_optimization_problem.py | 29 ++++++++-- 5 files changed, 78 insertions(+), 49 deletions(-) diff --git a/src/optimagic/optimization/internal_optimization_problem.py b/src/optimagic/optimization/internal_optimization_problem.py index f0951df74..0c9a9fc9b 100644 --- a/src/optimagic/optimization/internal_optimization_problem.py +++ b/src/optimagic/optimization/internal_optimization_problem.py @@ -23,6 +23,7 @@ Direction, ErrorHandling, EvalTask, + ExtraResultFields, PyTree, ) @@ -55,6 +56,7 @@ def __init__( linear_constraints: list[dict[str, Any]] | None, nonlinear_constraints: list[dict[str, Any]] | None, logger: LogStore[Any, Any] | None, + static_result_fields: ExtraResultFields, # TODO: add hess and hessp ): self._fun = fun @@ -73,6 +75,7 @@ def __init__( self._nonlinear_constraints = nonlinear_constraints self._logger = logger self._step_id: int | None = None + self._static_result_fields = static_result_fields # ================================================================================== # Public methods used by optimizers @@ -218,6 +221,14 @@ def bounds(self) -> InternalBounds: def logger(self) -> LogStore[Any, Any] | None: return self._logger + @property + def converter(self) -> Converter: + return self._converter + + @property + def static_result_fields(self) -> ExtraResultFields: + return self._static_result_fields + # ================================================================================== # Implementation of the public functions; The main difference is that the lower- # level implementations return a history entry instead of adding it to the history diff --git a/src/optimagic/optimization/optimize.py b/src/optimagic/optimization/optimize.py index 7935de635..fc1cbdeef 100644 --- a/src/optimagic/optimization/optimize.py +++ b/src/optimagic/optimization/optimize.py @@ -49,7 +49,6 @@ from optimagic.optimization.optimization_logging import log_scheduled_steps_and_get_ids from optimagic.optimization.optimize_result import OptimizeResult from optimagic.optimization.process_results import ( - ExtraResultFields, process_multistart_result, process_single_result, ) @@ -64,6 +63,7 @@ Direction, ErrorHandling, ErrorHandlingLiteral, + ExtraResultFields, NonNegativeFloat, PyTree, ) @@ -543,18 +543,6 @@ def _optimize(problem: OptimizationProblem) -> OptimizeResult: add_soft_bounds=problem.multistart is not None, ) - # ================================================================================== - # initialize the log database - # ================================================================================== - logger: LogStore[Any, Any] | None - - if problem.logging: - logger = LogStore.from_options(problem.logging) - problem_data = ProblemInitialization(problem.direction, problem.params) - logger.problem_store.insert(problem_data) - else: - logger = None - # ================================================================================== # Do some things that require internal parameters or bounds # ================================================================================== @@ -583,12 +571,37 @@ def _optimize(problem: OptimizationProblem) -> OptimizeResult: numdiff_options=problem.numdiff_options, skip_checks=problem.skip_checks, ) + # Define static information that will be added to the OptimizeResult + _scalar_start_criterion = cast( + float, first_crit_eval.internal_value(AggregationLevel.SCALAR) + ) + extra_fields = ExtraResultFields( + start_fun=_scalar_start_criterion, + start_params=problem.params, + algorithm=problem.algorithm.algo_info.name, + direction=problem.direction, + n_free=internal_params.free_mask.sum(), + ) + # create x and internal_bounds x = internal_params.values internal_bounds = InternalBounds( lower=internal_params.lower_bounds, upper=internal_params.upper_bounds, ) + + # ================================================================================== + # initialize the log database + # ================================================================================== + logger: LogStore[Any, Any] | None + + if problem.logging: + logger = LogStore.from_options(problem.logging) + problem_data = ProblemInitialization(problem.direction, problem.params) + logger.problem_store.insert(problem_data) + else: + logger = None + # ================================================================================== # Create a batch evaluator # ================================================================================== @@ -616,6 +629,7 @@ def _optimize(problem: OptimizationProblem) -> OptimizeResult: linear_constraints=None, nonlinear_constraints=internal_nonlinear_constraints, logger=logger, + static_result_fields=extra_fields, ) # ================================================================================== @@ -658,19 +672,6 @@ def _optimize(problem: OptimizationProblem) -> OptimizeResult: # Process the result # ================================================================================== - _scalar_start_criterion = cast( - float, first_crit_eval.internal_value(AggregationLevel.SCALAR) - ) - log_reader: LogReader[Any] | None - - extra_fields = ExtraResultFields( - start_fun=_scalar_start_criterion, - start_params=problem.params, - algorithm=problem.algorithm.algo_info.name, - direction=problem.direction, - n_free=internal_params.free_mask.sum(), - ) - if problem.multistart is None: res = process_single_result( raw_res=raw_res, @@ -686,6 +687,7 @@ def _optimize(problem: OptimizationProblem) -> OptimizeResult: extra_fields=extra_fields, ) + log_reader: LogReader[Any] | None if logger is not None: assert problem.logging is not None log_reader = LogReader.from_options(problem.logging) diff --git a/src/optimagic/optimization/process_results.py b/src/optimagic/optimization/process_results.py index 0817649f5..2843c7325 100644 --- a/src/optimagic/optimization/process_results.py +++ b/src/optimagic/optimization/process_results.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, replace +from dataclasses import replace from typing import Any import numpy as np @@ -7,21 +7,10 @@ from optimagic.optimization.convergence_report import get_convergence_report from optimagic.optimization.optimize_result import MultistartInfo, OptimizeResult from optimagic.parameters.conversion import Converter -from optimagic.typing import AggregationLevel, Direction, PyTree +from optimagic.typing import AggregationLevel, Direction, ExtraResultFields from optimagic.utilities import isscalar -@dataclass(frozen=True) -class ExtraResultFields: - """Fields for OptimizeResult that are not part of InternalOptimizeResult.""" - - start_fun: float - start_params: PyTree - algorithm: str - direction: Direction - n_free: int - - def process_single_result( raw_res: InternalOptimizeResult, converter: Converter, @@ -79,12 +68,7 @@ def process_multistart_result( solver_type: AggregationLevel, extra_fields: ExtraResultFields, ) -> OptimizeResult: - """Process results of internal optimizers. - - Args: - res (dict): Results dictionary of an internal optimizer or multistart optimizer. - - """ + """Process results of internal optimizers.""" if raw_res.multistart_info is None: raise ValueError("Multistart info is missing.") diff --git a/src/optimagic/typing.py b/src/optimagic/typing.py index 889152f79..87385d741 100644 --- a/src/optimagic/typing.py +++ b/src/optimagic/typing.py @@ -156,3 +156,14 @@ class MultiStartIterationHistory(TupleLikeAccess): history: IterationHistory local_histories: list[IterationHistory] | None = None exploration: IterationHistory | None = None + + +@dataclass(frozen=True) +class ExtraResultFields: + """Fields for OptimizeResult that are not part of InternalOptimizeResult.""" + + start_fun: float + start_params: PyTree + algorithm: str + direction: Direction + n_free: int diff --git a/tests/optimagic/optimization/test_internal_optimization_problem.py b/tests/optimagic/optimization/test_internal_optimization_problem.py index a0bb24a25..4c7f1f571 100644 --- a/tests/optimagic/optimization/test_internal_optimization_problem.py +++ b/tests/optimagic/optimization/test_internal_optimization_problem.py @@ -18,11 +18,29 @@ InternalOptimizationProblem, ) from optimagic.parameters.conversion import Converter -from optimagic.typing import AggregationLevel, Direction, ErrorHandling, EvalTask +from optimagic.typing import ( + AggregationLevel, + Direction, + ErrorHandling, + EvalTask, + ExtraResultFields, +) + + +@pytest.fixture +def extra_fields(): + out = ExtraResultFields( + start_fun=100, + start_params=np.arange(3), + algorithm="bla", + direction=Direction.MINIMIZE, + n_free=3, + ) + return out @pytest.fixture -def base_problem(): +def base_problem(extra_fields): """Set up a basic InternalOptimizationProblem that can be modified for tests.""" def fun(params): @@ -72,6 +90,7 @@ def fun_and_jac(params): linear_constraints=linear_constraints, nonlinear_constraints=nonlinear_constraints, logger=None, + static_result_fields=extra_fields, ) return problem @@ -413,7 +432,7 @@ def test_max_problem_exploration_fun(max_problem): @pytest.fixture -def pytree_problem(base_problem): +def pytree_problem(extra_fields): def fun(params): assert isinstance(params, dict) return LeastSquaresFunctionValue(value=params) @@ -479,6 +498,7 @@ def derivative_flatten(tree, x): linear_constraints=linear_constraints, nonlinear_constraints=nonlinear_constraints, logger=None, + static_result_fields=extra_fields, ) return problem @@ -543,7 +563,7 @@ def test_numerical_fun_and_jac_for_pytree_problem(pytree_problem): @pytest.fixture -def error_min_problem(): +def error_min_problem(extra_fields): """Set up a basic InternalOptimizationProblem that can be modified for tests.""" def fun(params): @@ -603,6 +623,7 @@ def fun_and_jac(params): linear_constraints=linear_constraints, nonlinear_constraints=nonlinear_constraints, logger=None, + static_result_fields=extra_fields, ) return problem