From 4dca0bc484d0f6a9ce6125e029f955ec57c1b361 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 28 Oct 2024 12:41:47 +0100 Subject: [PATCH] Allow for AlgorithmType in estimation functions (#535) --- src/estimagic/estimate_ml.py | 9 ++++-- src/estimagic/estimate_msm.py | 9 ++++-- src/optimagic/shared/check_option_dicts.py | 2 +- tests/estimagic/test_estimate_ml.py | 29 ++++++++++++++++++ tests/estimagic/test_estimate_msm.py | 35 ++++++++++++++++++++++ 5 files changed, 77 insertions(+), 7 deletions(-) diff --git a/src/estimagic/estimate_ml.py b/src/estimagic/estimate_ml.py index 39383c6a4..4044fee1d 100644 --- a/src/estimagic/estimate_ml.py +++ b/src/estimagic/estimate_ml.py @@ -95,8 +95,8 @@ def estimate_ml( optimize_options to False. Pytrees can be a numpy array, a pandas Series, a DataFrame with "value" column, a float and any kind of (nested) dictionary or list containing these elements. See :ref:`params` for examples. - optimize_options (dict, str or False): Keyword arguments that govern the - numerical optimization. Valid entries are all arguments of + optimize_options (dict, Algorithm, str or False): Keyword arguments that govern + the numerical optimization. Valid entries are all arguments of :func:`~estimagic.optimization.optimize.minimize` except for those that are passed explicilty to ``estimate_ml``. If you pass False as optimize_options you signal that ``params`` are already the optimal parameters and no @@ -199,7 +199,10 @@ def estimate_ml( is_optimized = optimize_options is False if not is_optimized: - if isinstance(optimize_options, str): + # If optimize_options is not a dictionary and not False, we assume it represents + # an algorithm. The actual testing of whether it is a valid algorithm is done + # when `maximize` is called. 
+ if not isinstance(optimize_options, dict): optimize_options = {"algorithm": optimize_options} check_optimization_options( diff --git a/src/estimagic/estimate_msm.py b/src/estimagic/estimate_msm.py index 57430bb99..b207e9829 100644 --- a/src/estimagic/estimate_msm.py +++ b/src/estimagic/estimate_msm.py @@ -107,8 +107,8 @@ def estimate_msm( optimize_options to False. Pytrees can be a numpy array, a pandas Series, a DataFrame with "value" column, a float and any kind of (nested) dictionary or list containing these elements. See :ref:`params` for examples. - optimize_options (dict, str or False): Keyword arguments that govern the - numerical optimization. Valid entries are all arguments of + optimize_options (dict, Algorithm, str or False): Keyword arguments that govern + the numerical optimization. Valid entries are all arguments of :func:`~estimagic.optimization.optimize.minimize` except for those that can be passed explicitly to ``estimate_msm``. If you pass False as ``optimize_options`` you signal that ``params`` are already @@ -199,7 +199,10 @@ def estimate_msm( is_optimized = optimize_options is False if not is_optimized: - if isinstance(optimize_options, str): + # If optimize_options is not a dictionary and not False, we assume it represents + # an algorithm. The actual testing of whether it is a valid algorithm is done + # when `minimize` is called. 
+ if not isinstance(optimize_options, dict): optimize_options = {"algorithm": optimize_options} check_optimization_options( diff --git a/src/optimagic/shared/check_option_dicts.py b/src/optimagic/shared/check_option_dicts.py index 82ace0201..c4c45fcc7 100644 --- a/src/optimagic/shared/check_option_dicts.py +++ b/src/optimagic/shared/check_option_dicts.py @@ -41,6 +41,6 @@ def check_optimization_options(options, usage, algorithm_mandatory=True): msg = ( "The following are not valid entries of optimize_options because they are " "not only relevant for minimization but also for inference: " - "{invalid_general}" + f"{invalid_general}" ) raise ValueError(msg) diff --git a/tests/estimagic/test_estimate_ml.py b/tests/estimagic/test_estimate_ml.py index 01714bee8..f3e806311 100644 --- a/tests/estimagic/test_estimate_ml.py +++ b/tests/estimagic/test_estimate_ml.py @@ -18,6 +18,7 @@ scalar_logit_fun_and_jac, ) from optimagic import mark +from optimagic.optimizers import scipy_optimizers from optimagic.parameters.bounds import Bounds @@ -349,6 +350,34 @@ def test_estimate_ml_optimize_options_false(fitted_logit_model, logit_np_inputs) aaae(got.cov(method="jacobian"), fitted_logit_model.covjac, decimal=4) +def test_estimate_ml_algorithm_type(logit_np_inputs): + """Test that estimate_ml accepts an algorithm class as optimize_options.""" + kwargs = {"y": logit_np_inputs["y"], "x": logit_np_inputs["x"]} + + params = pd.DataFrame({"value": logit_np_inputs["params"]}) + + estimate_ml( + loglike=logit_loglike, + params=params, + loglike_kwargs=kwargs, + optimize_options=scipy_optimizers.ScipyLBFGSB, + ) + + +def test_estimate_ml_algorithm(logit_np_inputs): + """Test that estimate_ml accepts an algorithm instance as optimize_options.""" + kwargs = {"y": logit_np_inputs["y"], "x": logit_np_inputs["x"]} + + params = pd.DataFrame({"value": logit_np_inputs["params"]}) + + estimate_ml( + loglike=logit_loglike, + params=params, + loglike_kwargs=kwargs, 
optimize_options=scipy_optimizers.ScipyLBFGSB(stopping_maxfun=10), + ) + + # ====================================================================================== # Univariate normal case using dict params # ====================================================================================== diff --git a/tests/estimagic/test_estimate_msm.py b/tests/estimagic/test_estimate_msm.py index 0c7a2d775..2684513d9 100644 --- a/tests/estimagic/test_estimate_msm.py +++ b/tests/estimagic/test_estimate_msm.py @@ -10,6 +10,7 @@ from estimagic.estimate_msm import estimate_msm from optimagic.optimization.optimize_result import OptimizeResult +from optimagic.optimizers import scipy_optimizers from optimagic.shared.check_option_dicts import ( check_optimization_options, ) @@ -161,6 +162,40 @@ def test_estimate_msm_with_jacobian(): aaae(calculated.cov(), cov_np) +def test_estimate_msm_with_algorithm_type(): + start_params = np.array([3, 2, 1]) + expected_params = np.zeros(3) + empirical_moments = _sim_np(expected_params) + if isinstance(empirical_moments, dict): + empirical_moments = empirical_moments["simulated_moments"] + + estimate_msm( + simulate_moments=_sim_np, + empirical_moments=empirical_moments, + moments_cov=cov_np, + params=start_params, + optimize_options=scipy_optimizers.ScipyLBFGSB, + jacobian=lambda x: np.eye(len(x)), + ) + + +def test_estimate_msm_with_algorithm(): + start_params = np.array([3, 2, 1]) + expected_params = np.zeros(3) + empirical_moments = _sim_np(expected_params) + if isinstance(empirical_moments, dict): + empirical_moments = empirical_moments["simulated_moments"] + + estimate_msm( + simulate_moments=_sim_np, + empirical_moments=empirical_moments, + moments_cov=cov_np, + params=start_params, + optimize_options=scipy_optimizers.ScipyLBFGSB(stopping_maxfun=10), + jacobian=lambda x: np.eye(len(x)), + ) + + def test_to_pickle(tmp_path): start_params = np.array([3, 2, 1])