Commit 7be41d3

Rewrite the proposal for custom derivatives and add more examples.
janosg committed Jun 28, 2024
1 parent 9fc18ce commit 7be41d3
Showing 1 changed file with 110 additions and 37 deletions:
docs/source/development/eep-02-typing.md

#### Proposal

In the current situation, the dictionary return type solves two different problems;
these will now be addressed separately.

##### Specifying different problem types

The simplest way of specifying a least-squares function becomes:

```python
import estimagic as em


@em.mark.least_squares
def ls_sphere(params):
    return params
```

The `em.mark.least_squares` decorator signals that the function returns least-squares
residuals rather than a scalar value. This mirrors other python libraries that support
specialized optimizers (e.g.
`scipy.optimize.least_squares`). The reason why we need the decorators is that we
support all kinds of optimizers in the same interface.
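
Because the decorator declares the problem type, the same function can be passed to
scalar and least-squares algorithms alike. A hedged usage sketch, reusing `ls_sphere`
from above (the algorithm names follow the `em.algorithms` pattern used later in this
document):

```python
import numpy as np

# a least-squares algorithm consumes the residuals directly
res = em.minimize(
    fun=ls_sphere,
    params=np.arange(3),
    algorithm=em.algorithms.scipy_ls_lm,
)

# a scalar algorithm works too; the scalar value is derived from the residuals
res = em.minimize(
    fun=ls_sphere,
    params=np.arange(3),
    algorithm=em.algorithms.scipy_lbfgsb,
)
```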

##### Returning additional information

If users additionally want to return information that should be stored in the log file,
they need to use a specific object as the return type:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FunctionValue:
    value: Any  # the scalar, residuals, or loglikes, depending on the problem type
    info: dict[str, Any]
```

An example of a least-squares function that also returns additional info for the log
file would look like this:

```python
from estimagic import FunctionValue
import estimagic as em


@em.mark.least_squares
def least_squares_sphere(params):
    # the residuals are the params themselves; the info entries are illustrative
    # examples of what would be stored in the log file
    return FunctionValue(
        value=params, info={"mean": params.mean(), "std": params.std()}
    )
```

This works analogously for scalar and likelihood functions, where again the
`mark.scalar` decorator is optional.
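
As a hedged sketch, the likelihood case could look like this (the function body and the
`info` entries are illustrative, not from the proposal text):

```python
@em.mark.likelihood
def loglike_sphere(params):
    # toy log-likelihood contributions; the info dict again goes to the log file
    return FunctionValue(value=-(params**2), info={"n_params": len(params)})
```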

##### Renaming `criterion` to `fun`

On top of the described changes, we suggest renaming `criterion` to `fun` to align the
naming with `scipy.optimize`. Similarly, `criterion_kwargs` should be renamed to
`fun_kwargs`.
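
A short sketch of the proposed names in use (the function and keyword values are
illustrative):

```python
def shifted_sphere(params, shift):
    return params @ params + shift


res = em.minimize(
    fun=shifted_sphere,  # previously `criterion`
    fun_kwargs={"shift": 1.0},  # previously `criterion_kwargs`
    params=np.arange(3),
    algorithm=em.algorithms.scipy_lbfgsb,
)
```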

##### Optionally replacing decorators with type hints

The purpose of the decorators is to tell us the output type of the criterion function.
This is necessary because there is no way of distinguishing between likelihood and
least-squares functions by their output alone: both return 1d arrays of floats.
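
As a purely illustrative sketch of this idea, a wrapper type in the return annotation
could carry the information that the decorator provides today (the `LeastSquaresValue`
name and its fields are hypothetical, not part of the proposal):

```python
from dataclasses import dataclass, field
from typing import Any

import numpy as np


# Hypothetical wrapper type: the annotation below, rather than a decorator,
# would tell estimagic that the function returns least-squares residuals.
@dataclass
class LeastSquaresValue:
    value: np.ndarray
    info: dict[str, Any] = field(default_factory=dict)


def ls_sphere_hinted(params: np.ndarray) -> LeastSquaresValue:
    return LeastSquaresValue(value=params)
```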

[…]

`criterion_and_derivative` returns a tuple of the criterion value and the derivative
instead.

#### Problems

- Autodiff needs to be handled completely outside of estimagic
- The names `criterion`, `derivative` and `criterion_and_derivative` are not aligned
with scipy and are very long.
- Providing derivatives to estimagic is perceived as complicated and confusing.

#### Proposal

We keep the three arguments but rename them to `fun`, `jac` and `fun_and_jac`.

To improve the integration with modern automatic differentiation frameworks, `jac` or
`fun_and_jac` can also be a string `"jax"` or a more autocomplete-friendly enum
`em.autodiff_backend.JAX`. This signals that the objective function is jax-compatible
and that jax should be used to calculate its derivatives. In the long run we can add
PyTorch support and more. Since this is mostly a signal of compatibility, it is enough
to set one of the two arguments to `"jax"`; the other can be left at `None`. Here is an
example:

```python
import jax.numpy as jnp
import estimagic as em


def jax_sphere(x):
    return jnp.dot(x, x)


res = em.minimize(
    fun=jax_sphere,
    params=jnp.arange(5.0),  # float params, so jax can differentiate them
    algorithm=em.algorithms.scipy_lbfgsb,
    jac="jax",
)
```

If a custom callable is provided as `jac` or `fun_and_jac`, it needs to be decorated
with `em.mark.least_squares` or `em.mark.likelihood` if it is not the gradient of a
scalar function value. Using the `em.mark.scalar` decorator is optional. For a simple
least-squares problem this looks as follows:

```python
import numpy as np


@em.mark.least_squares
def ls_sphere(params):
    return params


@em.mark.least_squares
def ls_sphere_jac(params):
    return np.eye(len(params))


res = em.minimize(
    fun=ls_sphere,
    params=np.arange(5),
    algorithm=em.algorithms.scipy_ls_lm,
    jac=ls_sphere_jac,
)
```

Note that here we have a least-squares problem and solve it with a least-squares
optimizer. However, any least-squares problem can also be solved with scalar optimizers.
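
For instance, reusing `ls_sphere` from above, a hedged sketch of solving the same
problem with a scalar optimizer (estimagic would derive the scalar value internally):

```python
res = em.minimize(
    fun=ls_sphere,
    params=np.arange(5),
    algorithm=em.algorithms.scipy_lbfgsb,
)
```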

While estimagic could convert the least-squares derivative to the gradient of the
scalar function value, this is generally inefficient: computing the full Jacobian of
the residuals is typically much more expensive than computing the scalar gradient
directly. Therefore, a user can provide multiple derivative callables in such a case,
so we can pick the one that matches the chosen optimizer:

```python
@em.mark.scalar
def sphere_grad(params):
    return 2 * params


res = em.minimize(
    fun=ls_sphere,
    params=np.arange(5),
    algorithm=em.algorithms.scipy_lbfgsb,
    jac=[ls_sphere_jac, sphere_grad],
)
```

Since a scalar optimizer was chosen to solve the least-squares problem, estimagic would
pick `sphere_grad` as the derivative. If a least-squares solver was chosen, we would use
`ls_sphere_jac`.
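
For intuition on the conversion estimagic would otherwise perform: for residuals `r(x)`
with Jacobian `J(x)`, the gradient of the scalar value `r(x) @ r(x)` is
`2 * J(x).T @ r(x)`. A minimal sketch (the helper name is ours, not estimagic API):

```python
def scalar_grad_from_ls(residuals, jacobian):
    # chain rule for a sum of squares: grad of sum(r**2) is 2 * J.T @ r
    return 2 * jacobian.T @ residuals


# for the sphere problem the Jacobian is the identity, so the converted
# gradient coincides with sphere_grad: 2 * params
x = np.arange(5.0)
assert np.allclose(scalar_grad_from_ls(x, np.eye(5)), 2 * x)
```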

### Other option dictionaries

[…] beartype during testing but it is not a priority for now.

### Suggested changes

| **Old Name** | **Proposed Name** | **Source** |
| ------------------------------------------ | ------------------------- | ---------- |
| `criterion` | `fun` | scipy |
| `criterion_kwargs` | `fun_kwargs` | |
| `derivative` | `jac` | scipy |
| `derivative_kwargs` | `jac_kwargs` | |
| `criterion_and_derivative` | `fun_and_jac` | |
| `criterion_and_derivative_kwargs` | `fun_and_jac_kwargs` | |
| `stopping_max_criterion_evaluations` | `stopping_maxfun` | scipy |
| `stopping_max_iterations` | `stopping_maxiter` | scipy |
| `convergence_absolute_criterion_tolerance` | `convergence_ftol_abs` | NlOpt |
| `convergence_relative_criterion_tolerance` | `convergence_ftol_rel` | NlOpt |
| `convergence_absolute_params_tolerance` | `convergence_xtol_abs` | NlOpt |
| `convergence_relative_params_tolerance` | `convergence_xtol_rel` | NlOpt |
| `convergence_absolute_gradient_tolerance` | `convergence_gtol_abs` | NlOpt |
| `convergence_relative_gradient_tolerance` | `convergence_gtol_rel` | NlOpt |
| `convergence_scaled_gradient_tolerance` | `convergence_gtol_scaled` | |

### Things we do not want to align

[…]
