Skip to content

Commit

Permalink
Add x0 as alias for params.
Browse files Browse the repository at this point in the history
  • Loading branch information
janosg committed Jul 17, 2024
1 parent d781907 commit 64cf934
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/optimagic/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ class MissingInputError(OptimagicError):
"""Exception for missing user provided input."""


class AliasError(OptimagicError):
"""Exception for aliasing errors."""


class InvalidKwargsError(OptimagicError):
"""Exception for invalid user provided keyword arguments."""

Expand Down
27 changes: 26 additions & 1 deletion src/optimagic/optimization/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
InvalidFunctionError,
InvalidKwargsError,
MissingInputError,
AliasError,
)
from optimagic.logging.create_tables import (
make_optimization_iteration_table,
Expand Down Expand Up @@ -63,6 +64,8 @@ def maximize(
multistart_options=None,
collect_history=True,
skip_checks=False,
# scipy aliases
x0=None,
# deprecated arguments
criterion=None,
criterion_kwargs=None,
Expand Down Expand Up @@ -99,6 +102,8 @@ def maximize(
multistart_options=multistart_options,
collect_history=collect_history,
skip_checks=skip_checks,
# scipy aliases
x0=x0,
# deprecated arguments
criterion=criterion,
criterion_kwargs=criterion_kwargs,
Expand Down Expand Up @@ -136,6 +141,8 @@ def minimize(
multistart_options=None,
collect_history=True,
skip_checks=False,
# scipy aliases
x0=None,
# deprecated arguments
criterion=None,
criterion_kwargs=None,
Expand Down Expand Up @@ -173,6 +180,8 @@ def minimize(
multistart_options=multistart_options,
collect_history=collect_history,
skip_checks=skip_checks,
# scipy aliases
x0=x0,
# deprecated arguments
criterion=criterion,
criterion_kwargs=criterion_kwargs,
Expand Down Expand Up @@ -211,6 +220,8 @@ def _optimize(
multistart_options,
collect_history,
skip_checks,
# scipy aliases
x0,
# deprecated arguments
criterion,
criterion_kwargs,
Expand Down Expand Up @@ -239,7 +250,7 @@ def _optimize(
)
raise MissingInputError(msg)

if params is None:
if params is None and x0 is None:
msg = (
"Missing start parameters. Please provide start parameters as the second "
"positional argument or as the keyword argument `params`."
Expand Down Expand Up @@ -328,6 +339,20 @@ def _optimize(
if fun_and_jac_kwargs is None:
fun_and_jac_kwargs = criterion_and_derivative_kwargs

# ==================================================================================
# handle scipy aliases
# ==================================================================================

if x0 is not None:
if params is not None:
msg = (
"x0 is an alias for params (for better compatibility with scipy). "
"Do not use both x0 and params."
)
raise AliasError(msg)
else:
params = x0

# ==================================================================================
# Set default values and check options
# ==================================================================================
Expand Down
43 changes: 43 additions & 0 deletions tests/optimagic/optimization/test_scipy_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import optimagic as om
import numpy as np
from numpy.testing import assert_array_almost_equal as aaae
from optimagic.exceptions import AliasError
import pytest


def test_x0_works_in_minimize():
res = om.minimize(
fun=lambda x: x @ x,
x0=np.arange(3),
algorithm="scipy_lbfgsb",
)
aaae(res.params, np.zeros(3))


def test_x0_works_in_maximize():
res = om.maximize(
fun=lambda x: -x @ x,
x0=np.arange(3),
algorithm="scipy_lbfgsb",
)
aaae(res.params, np.zeros(3))


def test_x0_and_params_do_not_work_together_in_minimize():
with pytest.raises(AliasError, match="x0 is an alias"):
om.minimize(
fun=lambda x: x @ x,
x0=np.arange(3),
params=np.arange(3),
algorithm="scipy_lbfgsb",
)


def test_x0_and_params_do_not_work_together_in_maximize():
with pytest.raises(AliasError, match="x0 is an alias"):
om.maximize(
fun=lambda x: -x @ x,
x0=np.arange(3),
params=np.arange(3),
algorithm="scipy_lbfgsb",
)

0 comments on commit 64cf934

Please sign in to comment.