
Support for loss function with auxiliary data in linesearch #1053

Open
ro0mquy opened this issue Sep 10, 2024 · 2 comments


ro0mquy commented Sep 10, 2024

I have a loss function that returns (loss_value, extra_data). Native JAX supports this construct via jax.value_and_grad(loss_fn, has_aux=True) (doc); the differentiated function then returns ((loss_value, extra_data), grad).
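For reference, a minimal sketch of the has_aux pattern (the toy loss function, its parameters, and the aux dictionary here are made up for illustration):

```python
import jax
import jax.numpy as jnp

# Toy loss that also returns auxiliary data (hypothetical example).
def loss_fn(params, x):
    residual = params * x
    loss = jnp.sum(residual ** 2)
    extra_data = {"residual": residual}  # auxiliary output
    return loss, extra_data

x = jnp.arange(3.0)
# has_aux=True tells JAX the function returns (scalar_loss, aux),
# so only the first output is differentiated.
(loss_value, extra_data), grad = jax.value_and_grad(loss_fn, has_aux=True)(2.0, x)
# loss_value == 20.0, grad (w.r.t. params) == 20.0
```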

In optax, when using the linesearch algorithms (for example as part of L-BFGS), I can use optax.value_and_grad_from_state(loss_fn) (doc), which reuses function evaluations already performed inside the linesearch by reading them back from the optimizer state. Unfortunately, neither the linesearch algorithms nor optax.value_and_grad_from_state supports auxiliary data.

I added support for this in my copy of the optax code, and it works for my use case. Are you interested in merging it upstream? I don't have time for proper testing, documentation, etc., though, so I would appreciate some assistance.
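In the meantime, one possible workaround (a sketch only, not the patch described above) is to strip the auxiliary output before handing the function to the linesearch, and call the original function directly when the aux is needed; the toy loss here is hypothetical:

```python
import jax
import jax.numpy as jnp

# Assumes a loss of the shape discussed in the issue: params -> (loss, extra_data).
def loss_fn(params):
    loss = jnp.sum(params ** 2)
    extra_data = {"max_abs": jnp.max(jnp.abs(params))}
    return loss, extra_data

def scalar_loss(params):
    # Drop the aux so the linesearch sees a plain scalar function.
    loss, _ = loss_fn(params)
    return loss

# scalar_loss can now be passed to optax.value_and_grad_from_state / the
# linesearch; call loss_fn directly (e.g. only at logging steps) for the aux.
value, grad = jax.value_and_grad(scalar_loss)(jnp.array([1.0, 2.0]))
# value == 5.0, grad == [2.0, 4.0]
```

The caveat is a duplicate forward pass whenever the aux is actually needed, which is precisely the recomputation value_and_grad_from_state is designed to avoid, so native has_aux support would still be the cleaner fix.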

Collaborator

vroulet commented Sep 10, 2024

Hello @ro0mquy,

I'd be happy to see how you handled it. I wasn't sure what the best way to add this would be while keeping the API light, so if you have an example, I'd be happy to look at a PR.

Thanks!

Author

ro0mquy commented Sep 11, 2024 via email
