
Make use of jaxopt.StochasticSolver.run_iterator #194

Open

bagibence opened this issue Jul 18, 2024 · 1 comment

Comments

@bagibence
Collaborator

The example in the docs currently uses a custom loop to implement stochastic gradient descent.

An alternative would be to make use of jaxopt.StochasticSolver.run_iterator and add support for stochastic solvers, potentially including optimizers implemented in Optax through the jaxopt.OptaxSolver wrapper.
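
For concreteness, a minimal sketch of what this could look like (the loss function, dataset iterator, and sizes below are placeholders, not NeMoS code). If I read the jaxopt API correctly, `run_iterator` consumes up to `maxiter` batches from the iterator and passes each one to the objective as the `data` keyword argument:

```python
import jax.numpy as jnp
import optax
from jaxopt import OptaxSolver

n_features = 10


def loss_fn(params, data):
    # `run_iterator` passes each batch as the `data` keyword argument.
    X, y = data
    return jnp.mean((X @ params - y) ** 2)


def batch_iterator():
    # Placeholder: yield (X_batch, y_batch) tuples, e.g. streamed from disk.
    while True:
        yield jnp.ones((32, n_features)), jnp.ones(32)


solver = OptaxSolver(opt=optax.adam(1e-3), fun=loss_fn, maxiter=1000)
# run_iterator stops after maxiter batches, so an endless iterator is fine.
params, state = solver.run_iterator(
    init_params=jnp.zeros(n_features), iterator=batch_iterator()
)
```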

Additionally, for data that fits in memory, it could be useful to add a faster version of this loop that samples mini-batches and updates the parameters internally (similar to ProxSVRG.run in #184).
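
For the in-memory case, a rough sketch of such a loop (the dataset, loss, and batch settings are made up for illustration): sample mini-batch indices with jax.random and call the solver's jitted update step directly, avoiding the Python iterator machinery:

```python
import jax
import jax.numpy as jnp
import optax
from jaxopt import OptaxSolver

# Illustrative in-memory dataset.
X = jax.random.normal(jax.random.PRNGKey(0), (1000, 10))
y = jnp.ones(1000)


def loss_fn(params, data):
    X_batch, y_batch = data
    return jnp.mean((X_batch @ params - y_batch) ** 2)


solver = OptaxSolver(opt=optax.sgd(1e-2), fun=loss_fn)
params = jnp.zeros(X.shape[1])
state = solver.init_state(params, data=(X, y))
update = jax.jit(solver.update)  # compile once, reuse every step

key = jax.random.PRNGKey(1)
batch_size, num_steps = 32, 500
for _ in range(num_steps):
    key, subkey = jax.random.split(key)
    # Sample a mini-batch without replacement and take one solver step.
    idx = jax.random.choice(subkey, X.shape[0], (batch_size,), replace=False)
    params, state = update(params, state, data=(X[idx], y[idx]))
```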

@BalzaniEdoardo
Collaborator

We could define a superclass that provides a general interface for stochastic solvers, with an inner/outer loop structure: the inner loop should be an abstract method, as should run_iterator for batching. The class should still provide a run method for data that fits in memory.
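
Something along these lines, perhaps (a loose sketch; the class and method names are hypothetical and this split of responsibilities is just one option, not an agreed design):

```python
from abc import ABC, abstractmethod

import jax


class AbstractStochasticSolver(ABC):
    """Template for stochastic solvers: abstract inner loop, concrete run."""

    @abstractmethod
    def init_state(self, init_params):
        """Create the initial solver state."""

    @abstractmethod
    def _inner_loop(self, params, state, batch):
        """Solver-specific updates on one mini-batch (e.g. SVRG's inner loop)."""

    @abstractmethod
    def _sample_batch(self, key, data):
        """Draw one mini-batch from in-memory data."""

    @abstractmethod
    def run_iterator(self, init_params, iterator):
        """Run on batches produced by a user-supplied iterator."""

    def run(self, init_params, data, key, num_steps):
        """Outer loop for data that fits in memory."""
        params, state = init_params, self.init_state(init_params)
        for _ in range(num_steps):
            key, subkey = jax.random.split(key)
            batch = self._sample_batch(subkey, data)
            params, state = self._inner_loop(params, state, batch)
        return params, state
```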
