Add (Prox-)SVRG as a solver for GLMs #184

Merged · Aug 15, 2024 · 198 commits

Commits
f1de528
Start working on the Prox-SVRG solver
bagibence Jun 21, 2024
ad45bc8
Adapt SVRG and ProxSVRG for use with nemos and allow their use in reg…
bagibence Jun 24, 2024
1c7a9d9
Make ProxSVRG inherit from SVRG and implement early stopping
bagibence Jun 24, 2024
7f5d27b
Speed up (Prox)SVRG with jit, scan, while_loop
bagibence Jun 25, 2024
ef1983d
Remove unnecessary jits
bagibence Jun 25, 2024
3bfa251
Make prox required
bagibence Jun 25, 2024
df751dd
Simplify from scan to fori_loop in update
bagibence Jun 25, 2024
50ea86f
Tiny changes (whitespace and tree_zeros_like)
bagibence Jun 25, 2024
f016f10
Save the initial naive implementation of (Prox)SVRG in solvers_naive
bagibence Jun 25, 2024
56f49a2
Add batch_size and type annotations for (Prox)SVRG
bagibence Jun 25, 2024
47429f9
Rename lr to stepsize in (Prox)SVRG
bagibence Jun 25, 2024
d09fc8d
Move stepsize into state, redifine error, adjust m based on batch_size
bagibence Jun 26, 2024
901adba
Add a (commented out) update method that implements just the inner lo…
bagibence Jun 26, 2024
ac10fa4
Change SVRG.update to update parameters on every data point
bagibence Jun 27, 2024
4b4d3b5
Change epoch_num back to iter_num
bagibence Jun 27, 2024
8905008
Pass args to SVRG.init_state to calculate N, xs, df_xs
bagibence Jun 27, 2024
32dcf6c
Separately define run and update in SVRG and ProxSVRG
bagibence Jun 27, 2024
eb520f1
Fix wrong argument passed to solver.init_state
bagibence Jun 27, 2024
a25e9af
Switch inheritance between SVRG and ProxSVRG
bagibence Jun 27, 2024
4e89a5f
Store the solver in the regularizer
bagibence Jun 27, 2024
372b233
Minor renaming and sanity check
bagibence Jun 27, 2024
34ec663
Remove stale and unreachable code
bagibence Jun 27, 2024
1c483d0
Update solvers to work with mini-batch updates
bagibence Jun 28, 2024
30c7e59
Add x_av to SVRGState and use it in the return value of ProxSVRG.update
bagibence Jun 28, 2024
a48fe7e
Log the loss after each update
bagibence Jun 28, 2024
fa95695
Move the loss logging into the run method
bagibence Jun 28, 2024
1b855e5
Have separate _update_per_batch and _update_per_point methods
bagibence Jul 2, 2024
ccd554e
Add batch_size again and have different update methods
bagibence Jul 8, 2024
811e48b
Use different method for run based on batch size
bagibence Jul 8, 2024
f1618ae
Potentially pre-generate random indices instead of one-by-one
bagibence Jul 10, 2024
f7825cc
Add docstrings and clean up ProxSVRG
bagibence Jul 11, 2024
efc507f
Change how solvers are found: first look in solvers, then in jaxopt
bagibence Jul 11, 2024
e489a15
Switch to new-style RNG keys and remove the uint32 conversion
bagibence Jul 11, 2024
f6b579a
Change default convergence tolerance in (Prox)SVRG
bagibence Jul 12, 2024
3bce956
Add yet another way of defining the error
bagibence Jul 12, 2024
75ec116
Add missing import and change the error calculation to the previous one
bagibence Jul 12, 2024
4b6064f
Make parameters explicit where possible and remove indexing into the …
bagibence Jul 12, 2024
85b366d
added regr testing
BalzaniEdoardo Jul 12, 2024
599c390
merged conflicts
BalzaniEdoardo Jul 12, 2024
7d73a39
improved test data generation
BalzaniEdoardo Jul 12, 2024
212753f
linted conftest.py
BalzaniEdoardo Jul 12, 2024
280221c
added test for tree utils
BalzaniEdoardo Jul 12, 2024
a911fbe
updated description of tree_slice
BalzaniEdoardo Jul 12, 2024
cce74f2
add context manager to test
BalzaniEdoardo Jul 12, 2024
b3c0bf3
remove unused imports
BalzaniEdoardo Jul 12, 2024
41d5d01
uniform srtepsize
BalzaniEdoardo Jul 12, 2024
5759d18
drop unused line
BalzaniEdoardo Jul 12, 2024
1f3d659
FIXED TOL
BalzaniEdoardo Jul 12, 2024
6faa685
Reformat with black
bagibence Jul 15, 2024
66c73ac
Initialize loss_log with zeros instead of empty and run isort
bagibence Jul 15, 2024
ff09dce
Fix flake8's complaints
bagibence Jul 17, 2024
ae99cfc
Add a few tests for SVRG
bagibence Jul 18, 2024
ccbd41f
Add integration tests for (Prox)SVRG
bagibence Jul 19, 2024
b9644bc
Increment iter_num after each xk_update and comment out loss logging
bagibence Jul 19, 2024
549fdfc
Revert the iter_num increment frequency
bagibence Jul 19, 2024
4e66c9e
Test that GLM.fit runs with (Prox)SVRG
bagibence Jul 19, 2024
7ba0a3b
Test that calls to SVRG.update with a naive implementation of the res…
bagibence Jul 19, 2024
0816f13
Remove asserts and update test
bagibence Jul 19, 2024
be9ea0f
Change iter_num incrementation again and update the test for it
bagibence Jul 19, 2024
ee7b6e4
Change iter_num again to mean number of epochs and adjust test
bagibence Jul 19, 2024
bb3a2bf
Copy tree_utils from jaxopt
bagibence Jul 19, 2024
cb626a1
Add test for SVRG._xk_update
bagibence Jul 19, 2024
13f7b43
Use tree_slice of y as well
bagibence Jul 22, 2024
d97bd9b
Update tests
bagibence Jul 22, 2024
12fd6b7
fixed pop glm proximal gradient
BalzaniEdoardo Jul 22, 2024
32f8b02
Switch from X,y to *args in ProxSVRG. Leave prox_lamba explicit
bagibence Jul 23, 2024
6306313
Add ProxSVRG._run to handle prox_lambda better
bagibence Jul 23, 2024
953ef6f
fixing prox_lasso
BalzaniEdoardo Jul 23, 2024
9b3eea3
use our own lasso
BalzaniEdoardo Jul 23, 2024
93961be
added test for prox_lasso
BalzaniEdoardo Jul 23, 2024
5a10d8f
Test SVRG._xk_update with our own prox_lasso instead of jaxopt's
bagibence Jul 23, 2024
2344895
Test that GLM.fit converges to the same as sklearn's PoissonRegressor…
bagibence Jul 23, 2024
a4f588c
Remove prints and spaces
bagibence Jul 23, 2024
7838aab
Test that GLM.update with a naive SVRG outer loop gives the same as G…
bagibence Jul 23, 2024
c4e1a7a
Start working on the Prox-SVRG solver
bagibence Jun 21, 2024
f4dbb0e
Adapt SVRG and ProxSVRG for use with nemos and allow their use in reg…
bagibence Jun 24, 2024
6549266
Make ProxSVRG inherit from SVRG and implement early stopping
bagibence Jun 24, 2024
763be7b
Speed up (Prox)SVRG with jit, scan, while_loop
bagibence Jun 25, 2024
c967fb0
Remove unnecessary jits
bagibence Jun 25, 2024
5bb47cf
Make prox required
bagibence Jun 25, 2024
33d6c7b
Simplify from scan to fori_loop in update
bagibence Jun 25, 2024
4e2393d
Tiny changes (whitespace and tree_zeros_like)
bagibence Jun 25, 2024
25c44b9
Save the initial naive implementation of (Prox)SVRG in solvers_naive
bagibence Jun 25, 2024
0574d93
Add batch_size and type annotations for (Prox)SVRG
bagibence Jun 25, 2024
8278f45
Rename lr to stepsize in (Prox)SVRG
bagibence Jun 25, 2024
d881321
Move stepsize into state, redifine error, adjust m based on batch_size
bagibence Jun 26, 2024
3c2d560
Add a (commented out) update method that implements just the inner lo…
bagibence Jun 26, 2024
4c5bf27
Change SVRG.update to update parameters on every data point
bagibence Jun 27, 2024
245bdaa
Change epoch_num back to iter_num
bagibence Jun 27, 2024
6553737
Pass args to SVRG.init_state to calculate N, xs, df_xs
bagibence Jun 27, 2024
5f629a7
Separately define run and update in SVRG and ProxSVRG
bagibence Jun 27, 2024
6bcdb3f
Switch inheritance between SVRG and ProxSVRG
bagibence Jun 27, 2024
b0d8b93
Minor renaming and sanity check
bagibence Jun 27, 2024
1f12c46
Remove stale and unreachable code
bagibence Jun 27, 2024
e0a5d3f
Update solvers to work with mini-batch updates
bagibence Jun 28, 2024
26481e9
Add x_av to SVRGState and use it in the return value of ProxSVRG.update
bagibence Jun 28, 2024
6850f7e
Log the loss after each update
bagibence Jun 28, 2024
7461caf
Move the loss logging into the run method
bagibence Jun 28, 2024
d9539d4
Have separate _update_per_batch and _update_per_point methods
bagibence Jul 2, 2024
e068b3b
Add batch_size again and have different update methods
bagibence Jul 8, 2024
e26af13
Use different method for run based on batch size
bagibence Jul 8, 2024
ec24269
Potentially pre-generate random indices instead of one-by-one
bagibence Jul 10, 2024
ebfe409
Add docstrings and clean up ProxSVRG
bagibence Jul 11, 2024
b52b71d
Switch to new-style RNG keys and remove the uint32 conversion
bagibence Jul 11, 2024
7c8e7d5
Change default convergence tolerance in (Prox)SVRG
bagibence Jul 12, 2024
a599cd9
Add yet another way of defining the error
bagibence Jul 12, 2024
d94edf5
Add missing import and change the error calculation to the previous one
bagibence Jul 12, 2024
5343a15
added regr testing
BalzaniEdoardo Jul 12, 2024
2044c3f
Make parameters explicit where possible and remove indexing into the …
bagibence Jul 12, 2024
5c600db
improved test data generation
BalzaniEdoardo Jul 12, 2024
cd47403
linted conftest.py
BalzaniEdoardo Jul 12, 2024
a668acf
added test for tree utils
BalzaniEdoardo Jul 12, 2024
00df370
updated description of tree_slice
BalzaniEdoardo Jul 12, 2024
761ab73
add context manager to test
BalzaniEdoardo Jul 12, 2024
f0be599
remove unused imports
BalzaniEdoardo Jul 12, 2024
194b432
uniform srtepsize
BalzaniEdoardo Jul 12, 2024
ecd4ef4
drop unused line
BalzaniEdoardo Jul 12, 2024
e17a495
FIXED TOL
BalzaniEdoardo Jul 12, 2024
d10e507
Initialize loss_log with zeros instead of empty and run isort
bagibence Jul 15, 2024
41b627f
Fix flake8's complaints
bagibence Jul 17, 2024
bc32794
Add a few tests for SVRG
bagibence Jul 18, 2024
846ed43
Add integration tests for (Prox)SVRG
bagibence Jul 19, 2024
8012176
Increment iter_num after each xk_update and comment out loss logging
bagibence Jul 19, 2024
a75f55b
Revert the iter_num increment frequency
bagibence Jul 19, 2024
7fb3d9d
Test that GLM.fit runs with (Prox)SVRG
bagibence Jul 19, 2024
f751c76
Test that calls to SVRG.update with a naive implementation of the res…
bagibence Jul 19, 2024
ad118c7
Remove asserts and update test
bagibence Jul 19, 2024
21273eb
Change iter_num incrementation again and update the test for it
bagibence Jul 19, 2024
4d8fb3f
Change iter_num again to mean number of epochs and adjust test
bagibence Jul 19, 2024
adbe0a2
Copy tree_utils from jaxopt
bagibence Jul 19, 2024
56f1ff3
Add test for SVRG._xk_update
bagibence Jul 19, 2024
8f743df
Use tree_slice of y as well
bagibence Jul 22, 2024
1ca062f
Update tests
bagibence Jul 22, 2024
f51f6d2
fixed pop glm proximal gradient
BalzaniEdoardo Jul 22, 2024
3004c14
Switch from X,y to *args in ProxSVRG. Leave prox_lamba explicit
bagibence Jul 23, 2024
eac538e
Add ProxSVRG._run to handle prox_lambda better
bagibence Jul 23, 2024
0d673cc
fixing prox_lasso
BalzaniEdoardo Jul 23, 2024
457397e
use our own lasso
BalzaniEdoardo Jul 23, 2024
d58cb16
added test for prox_lasso
BalzaniEdoardo Jul 23, 2024
2ae4a13
Test SVRG._xk_update with our own prox_lasso instead of jaxopt's
bagibence Jul 23, 2024
a885cda
Test that GLM.fit converges to the same as sklearn's PoissonRegressor…
bagibence Jul 23, 2024
3de8780
Test that GLM.update with a naive SVRG outer loop gives the same as G…
bagibence Jul 23, 2024
231c0a6
Bit of tidying up in ProxSVRG
bagibence Jul 26, 2024
82fa1a6
Make BaseRegressor aware of nemos.solvers
bagibence Jul 26, 2024
efb97c9
Allow using SVRG and ProxSVRG with regularizers
bagibence Jul 26, 2024
83b093c
Update tests
bagibence Jul 26, 2024
7b0ff57
reconciled history
BalzaniEdoardo Jul 26, 2024
06ed5c7
updated reg
BalzaniEdoardo Jul 26, 2024
cc11af4
fixed error message and test
BalzaniEdoardo Jul 26, 2024
710f484
linting (with partial flake8 on test)
BalzaniEdoardo Jul 26, 2024
9fffdd1
bugfixed mask creation
BalzaniEdoardo Jul 26, 2024
52e9c2e
fixed ellipsis
BalzaniEdoardo Jul 26, 2024
500760a
fixed flake8
BalzaniEdoardo Jul 26, 2024
81a5f0f
removed ellipsis
BalzaniEdoardo Jul 26, 2024
cf56b89
removed ellipsis
BalzaniEdoardo Jul 26, 2024
ff9b9b9
Only store the loss function in BaseRegressor, not the solver
bagibence Jul 26, 2024
2dbe45a
Remove unused solvers_naive.py
bagibence Aug 2, 2024
1f44561
Remove logging the loss in ProxSVRG
bagibence Aug 2, 2024
f327cac
Remove x_av and x_sum in ProxSVRG
bagibence Aug 2, 2024
3f39005
Change default random key from 0 to 123
bagibence Aug 2, 2024
e54726b
Remove unused error calculations and add docstring for the used one i…
bagibence Aug 2, 2024
6926bdc
Apply docstring spelling suggestions from code review
bagibence Aug 2, 2024
cb610dd
Start improving citations
bagibence Aug 5, 2024
4a66ac7
Correct references
bagibence Aug 5, 2024
e8ddf3f
Remove unused imports
bagibence Aug 5, 2024
01a39b4
Remove iteration through args when using `tree_slice`
bagibence Aug 5, 2024
335860a
Remove type hints from the docstrings
bagibence Aug 5, 2024
666afa2
Update description of *args in ProxSVRG methods
bagibence Aug 5, 2024
46653f9
Type hint parameters as Pytree
bagibence Aug 5, 2024
d3a904f
... finishing the previous commit
bagibence Aug 5, 2024
0285b78
Add "Raises" sections in (Prox)SVRG
bagibence Aug 5, 2024
2bad7ac
Correct default random key in SVRG tests
bagibence Aug 5, 2024
e1293f4
Attempt to use more descriptive variable names
bagibence Aug 5, 2024
167c8dd
Fix example formatting in (Prox)SVRG docstrings
bagibence Aug 6, 2024
ca354fe
Make (Prox)SVRG.update docstring more precise and generic
bagibence Aug 6, 2024
41aa456
Remove duplicate comment
bagibence Aug 6, 2024
bdfc35b
Add docstring for SVRGState and update references
bagibence Aug 7, 2024
e5904aa
renamed fileds
BalzaniEdoardo Aug 8, 2024
214c9da
Change variable naming to be more descriptive
bagibence Aug 8, 2024
ce50fbd
Fix references in the docstrings
bagibence Aug 8, 2024
bfc4cfc
Remove init_full_gradient argument from SVRG.init_state
bagibence Aug 8, 2024
46fe498
merged development
BalzaniEdoardo Aug 10, 2024
cd4660b
flake8 fixes
BalzaniEdoardo Aug 10, 2024
61478a4
merged with origin
BalzaniEdoardo Aug 10, 2024
8c74bd9
fixed a deprecation warning
BalzaniEdoardo Aug 10, 2024
5f33bd6
merged dev
BalzaniEdoardo Aug 13, 2024
493f090
removed unused arg
BalzaniEdoardo Aug 13, 2024
3987254
Update src/nemos/solvers.py
BalzaniEdoardo Aug 13, 2024
0c3aabb
removed line
BalzaniEdoardo Aug 13, 2024
ca0b51a
partial solution to references
BalzaniEdoardo Aug 13, 2024
4eb3925
fix references
BalzaniEdoardo Aug 14, 2024
bdb1d3d
linted
BalzaniEdoardo Aug 14, 2024
1537645
merged
BalzaniEdoardo Aug 14, 2024
d39ea73
Merge branch 'development' into svrg
BalzaniEdoardo Aug 14, 2024
f5528bc
Update src/nemos/solvers.py
BalzaniEdoardo Aug 15, 2024
0ead1bf
Merge branch 'development' into svrg
BalzaniEdoardo Aug 15, 2024
a318e33
improved docstrings
BalzaniEdoardo Aug 15, 2024
0d539d2
linted
BalzaniEdoardo Aug 15, 2024
30 changes: 15 additions & 15 deletions src/nemos/basis.py
@@ -1350,7 +1350,7 @@ def _check_n_basis_min(self) -> None:

class MSplineBasis(SplineBasis):
r"""
M-spline[$^1$](#references) basis functions for modeling and data transformation.
M-spline[$^{[1]}$](#references) basis functions for modeling and data transformation.

M-splines are a type of spline basis function used for smooth curve fitting
and data representation. They are positive and integrate to one, making them
@@ -1394,8 +1394,8 @@ class MSplineBasis(SplineBasis):
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = mspline_basis(sample_points)

References
----------
# References
------------
[1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science,
3(4), 425-441.

Expand Down Expand Up @@ -1517,7 +1517,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:

class BSplineBasis(SplineBasis):
"""
B-spline[$^1$](#references) 1-dimensional basis functions.
B-spline[$^{[1]}$](#references) 1-dimensional basis functions.

Parameters
----------
@@ -1546,9 +1546,9 @@ class BSplineBasis(SplineBasis):
Spline order.


References
----------
1. Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques.
# References
------------
[1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques.
Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5

"""
@@ -1779,7 +1779,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
class RaisedCosineBasisLinear(Basis):
"""Represent linearly-spaced raised cosine basis functions.

This implementation is based on the cosine bumps used by Pillow et al.[$^1$](#references)
This implementation is based on the cosine bumps used by Pillow et al.[$^{[1]}$](#references)
to uniformly tile the internal points of the domain.

Parameters
@@ -1801,9 +1801,9 @@ class RaisedCosineBasisLinear(Basis):
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`

References
----------
1. Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
# References
------------
[1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
C. E. (2005). Prediction and decoding of retinal ganglion cell responses
with a probabilistic spiking model. Journal of Neuroscience, 25(47),
11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005
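
For readers unfamiliar with the construction: a "cosine bump" here is a half-period raised cosine, shifted copies of which tile the domain. A standalone sketch of the standard formula follows (illustrative only; not code from this PR):

```python
# One raised cosine bump in its standard form; shifting `center`
# produces the linearly spaced tiling described above. Illustrative only.
import jax.numpy as jnp

def raised_cosine_bump(x: jnp.ndarray, center: float, width: float) -> jnp.ndarray:
    # 0.5 * (1 + cos(phase)) with the phase clipped to [-pi, pi], so each
    # bump is smooth, nonnegative, and zero outside its support.
    phase = jnp.clip((x - center) * jnp.pi / width, -jnp.pi, jnp.pi)
    return 0.5 * (1.0 + jnp.cos(phase))

x = jnp.linspace(0.0, 1.0, 200)
bump = raised_cosine_bump(x, center=0.5, width=0.25)
```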
@@ -1964,7 +1964,7 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear):
"""Represent log-spaced raised cosine basis functions.

Similar to `RaisedCosineBasisLinear` but the basis functions are log-spaced.
This implementation is based on the cosine bumps used by Pillow et al.[$^1$](#references)
This implementation is based on the cosine bumps used by Pillow et al.[$^{[1]}$](#references)
to uniformly tile the internal points of the domain.

Parameters
@@ -1994,9 +1994,9 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear):
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`

References
----------
1. Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
# References
------------
[1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
C. E. (2005). Prediction and decoding of retinal ganglion cell responses
with a probabilistic spiking model. Journal of Neuroscience, 25(47),
11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005
14 changes: 7 additions & 7 deletions src/nemos/observation_models.py
@@ -267,8 +267,8 @@ def pseudo_r2(
) -> jnp.ndarray:
r"""Pseudo-$R^2$ calculation for a GLM.

Compute the pseudo-$R^2$ metric for the GLM, as defined by McFadden et al.[$^1$](#references)
or by Cohen et al.[$^2$](#references).
Compute the pseudo-$R^2$ metric for the GLM, as defined by McFadden et al.[$^{[1]}$](#references)
or by Cohen et al.[$^{[2]}$](#references).

This metric evaluates the goodness-of-fit of the model relative to a null (baseline) model that assumes a
constant mean for the observations. While the pseudo-$R^2$ is bounded between 0 and 1 for the training set,
@@ -311,13 +311,13 @@ def pseudo_r2(
sample, i.e. the maximum value that the likelihood could possibly achieve). $D_M$ and $D_0$ are
the model and the null deviance, $D_i = -2 \left[ \log(L_s) - \log(L_i) \right]$ for $i=M,0$.


References
----------
1. McFadden D (1979). Quantitative methods for analysing travel behavior of individuals: Some recent
# References
------------
[1] McFadden D (1979). Quantitative methods for analysing travel behavior of individuals: Some recent
developments. In D. A. Hensher & P. R. Stopher (Eds.), *Behavioural travel modelling* (pp. 279-318).
London: Croom Helm.
2. Jacob Cohen, Patricia Cohen, Steven G. West, Leona S. Aiken.

[2] Jacob Cohen, Patricia Cohen, Steven G. West, Leona S. Aiken.
*Applied Multiple Regression/Correlation Analysis for the Behavioral Sciences*.
3rd edition. Routledge, 2002. p.502. ISBN 978-0-8058-2223-6. (May 2012)
"""
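
Written out from the deviance definition in the docstring above, the two cited variants take their standard forms, with $L_M$, $L_0$, $L_s$ the model, null, and saturated likelihoods:

```latex
% McFadden's pseudo-R^2 compares model and null log-likelihoods directly:
R^2_{\mathrm{McFadden}} = 1 - \frac{\log L_M}{\log L_0}

% Cohen's variant is the fraction of the null deviance explained, using
% D_i = -2\left[\log(L_s) - \log(L_i)\right] for i = M, 0:
R^2_{\mathrm{Cohen}} = \frac{D_0 - D_M}{D_0}
```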
2 changes: 1 addition & 1 deletion src/nemos/simulation.py
@@ -60,7 +60,7 @@ def difference_of_gammas(

References
----------
1. [SciPy Docs - "scipy.stats.gamma"](https://docs.scipy.org/doc/
[1] [SciPy Docs - "scipy.stats.gamma"](https://docs.scipy.org/doc/
scipy/reference/generated/scipy.stats.gamma.html)
"""
# check that the gamma parameters are positive (scipy returns
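
As the name suggests, this function builds a biphasic filter as the difference of two gamma densities; a minimal sketch per the SciPy reference above, with all parameter values hypothetical:

```python
# Minimal sketch of a difference-of-gammas filter; illustrative values only.
import numpy as np
from scipy.stats import gamma

t = np.linspace(0.0, 1.0, 200)
excitatory = gamma.pdf(t, a=2.0, scale=0.05)  # fast positive lobe
inhibitory = gamma.pdf(t, a=4.0, scale=0.10)  # slower negative lobe
coupling_filter = excitatory - 0.8 * inhibitory
```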
39 changes: 18 additions & 21 deletions src/nemos/solvers.py
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Callable, NamedTuple, Optional, Union
from typing import Callable, NamedTuple, Optional, Union

import jax
import jax.flatten_util
@@ -30,15 +30,15 @@ class SVRGState(NamedTuple):
Step size of the individual gradient steps.
reference_point :
Anchor/reference/snapshot point where the full gradient is calculated in the SVRG algorithm.
Corresponds to `x_{s}` in the pseudocode in [1]
Corresponds to `x_{s}` in the pseudocode[$^{[1]}$](#references).
full_grad_at_reference_point :
Full gradient at the anchor/reference point.

References
----------
1. [Gower, Robert M., Mark Schmidt, Francis Bach, and Peter Richtárik.
"Variance-Reduced Methods for Machine Learning." arXiv preprint arXiv:2010.00892 (2020).
](https://arxiv.org/abs/2010.00892)
# References
------------
[1] [Gower, Robert M., Mark Schmidt, Francis Bach, and Peter Richtárik.
"Variance-Reduced Methods for Machine Learning." arXiv preprint arXiv:2010.00892 (2020).
](https://arxiv.org/abs/2010.00892)
"""

iter_num: int
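
For readers skimming the diff, here is a schematic of what such a state carries. The fields mirror the docstring above; the `key` field and the initial values are assumptions inferred from the commit history (new-style RNG keys, default key changed from 0 to 123).

```python
# Schematic sketch of the solver state; not the package's actual code.
from typing import NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp


class SketchSVRGState(NamedTuple):
    iter_num: int           # epochs completed so far
    key: jax.Array          # PRNG key for sampling mini-batches (assumed field)
    stepsize: float         # step size of the individual gradient steps
    reference_point: Tuple  # anchor x_s where the full gradient is computed
    full_grad_at_reference_point: Optional[Tuple]  # full gradient at the anchor


init_params = (jnp.zeros((5,)), jnp.zeros(()))  # (W, b) for a GLM
state = SketchSVRGState(
    iter_num=0,
    key=jax.random.key(123),  # the PR switches the default key from 0 to 123
    stepsize=1e-3,
    reference_point=init_params,
    full_grad_at_reference_point=None,
)
```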
@@ -88,13 +88,15 @@ class ProxSVRG:

References
----------
1. [Gower, Robert M., Mark Schmidt, Francis Bach, and Peter Richtárik.
[1] [Gower, Robert M., Mark Schmidt, Francis Bach, and Peter Richtárik.
"Variance-Reduced Methods for Machine Learning." arXiv preprint arXiv:2010.00892 (2020).
](https://arxiv.org/abs/2010.00892)
2. [Xiao, Lin, and Tong Zhang.

[2] [Xiao, Lin, and Tong Zhang.
"A proximal stochastic gradient method with progressive variance reduction."
SIAM Journal on Optimization 24.4 (2014): 2057-2075.](https://arxiv.org/abs/1403.4699v1)
3. [Johnson, Rie, and Tong Zhang.

[3] [Johnson, Rie, and Tong Zhang.
"Accelerating stochastic gradient descent using predictive variance reduction."
Advances in neural information processing systems 26 (2013).
](https://proceedings.neurips.cc/paper/2013/hash/ac1dd209cbcc5e5d1c6e28598e8cbbe8-Abstract.html)
@@ -122,7 +124,6 @@ def __init__(
def init_state(
self,
init_params: Pytree,
hyperparams_prox: Any,
*args,
) -> SVRGState:
"""
@@ -133,9 +134,6 @@ def init_state(
init_params :
Pytree containing the initial parameters.
For GLMs it's a tuple of (W, b)
hyperparams_prox :
Parameters of the proximal operator, in our case the regularization strength.
Not used here, but required to be consistent with the jaxopt API.
args:
Positional arguments passed to loss function `fun` and its gradient (e.g. `fun(params, *args)`),
most likely input and output data.
@@ -384,7 +382,6 @@ def run(
# initialize the state, including the full gradient at the initial parameters
init_state = self.init_state(
init_params,
prox_lambda,
*args,
)
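
A hypothetical end-to-end sketch of the API change above: `init_state` no longer receives `hyperparams_prox`, while `run` still threads the regularization strength through to the proximal operator. The constructor and argument order shown are assumptions based on this diff and jaxopt conventions, not a verbatim nemos example.

```python
# Hypothetical usage sketch; signatures are assumptions, not the verbatim API.
import jax.numpy as jnp
from jaxopt.prox import prox_lasso  # later commits swap in nemos' own prox_lasso

from nemos.solvers import ProxSVRG


def loss(params, X, y):
    W, b = params
    return jnp.mean((X @ W + b - y) ** 2)


X, y = jnp.ones((100, 5)), jnp.zeros((100,))
init_params = (jnp.zeros((5,)), jnp.zeros(()))
solver = ProxSVRG(loss, prox_lasso, batch_size=1, stepsize=1e-3)

# After this change, init_state takes only the parameters and the data...
state = solver.init_state(init_params, X, y)
# ...while run still takes the regularization strength explicitly.
params, state = solver.run(init_params, 0.1, X, y)  # 0.1 = prox_lambda
```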

@@ -535,7 +532,6 @@ def _update_per_random_samples(
N = n_points_per_arg.pop()

m = (N + self.batch_size - 1) // self.batch_size # number of iterations
# m = N

def inner_loop_body(_, carry):
params, key = carry
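
The `m` computed above is just the ceiling of `N / batch_size`: the number of inner-loop steps needed for the mini-batches to cover roughly all N points once per epoch. A quick check of the integer idiom:

```python
# (N + b - 1) // b is ceiling division without going through floats.
import math

N, batch_size = 1000, 32
m = (N + batch_size - 1) // batch_size
assert m == math.ceil(N / batch_size) == 32  # 31 full batches + 1 partial batch
```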
@@ -625,12 +621,14 @@ class SVRG(ProxSVRG):

References
----------
1. [Gower, Robert M., Mark Schmidt, Francis Bach, and Peter Richtárik.
[1] [Gower, Robert M., Mark Schmidt, Francis Bach, and Peter Richtárik.
"Variance-Reduced Methods for Machine Learning." arXiv preprint arXiv:2010.00892 (2020).
](https://arxiv.org/abs/2010.00892)
2. [Xiao, Lin, and Tong Zhang. "A proximal stochastic gradient method with progressive variance reduction."

[2] [Xiao, Lin, and Tong Zhang. "A proximal stochastic gradient method with progressive variance reduction."
SIAM Journal on Optimization 24.4 (2014): 2057-2075.](https://arxiv.org/abs/1403.4699v1)
3. [Johnson, Rie, and Tong Zhang. "Accelerating stochastic gradient descent using predictive variance reduction."

[3] [Johnson, Rie, and Tong Zhang. "Accelerating stochastic gradient descent using predictive variance reduction."
Advances in neural information processing systems 26 (2013).
](https://proceedings.neurips.cc/paper/2013/hash/ac1dd209cbcc5e5d1c6e28598e8cbbe8-Abstract.html)
"""
@@ -662,7 +660,6 @@ def init_state(self, init_params: Pytree, *args, **kwargs) -> SVRGState:
----------
init_params :
pytree containing the initial parameters.
For GLMs it's a tuple of (W, b)
args:
Positional arguments passed to loss function `fun` and its gradient (e.g. `fun(params, *args)`),
most likely input and output data.
@@ -680,7 +677,7 @@ def init_state(self, init_params: Pytree, *args, **kwargs) -> SVRGState:
Initialized optimizer state
"""
# substitute None for prox_lambda
return super().init_state(init_params, None, *args, **kwargs)
return super().init_state(init_params, *args, **kwargs)

@partial(jit, static_argnums=(0,))
def update(self, params: Pytree, state: SVRGState, *args, **kwargs) -> OptStep:
6 changes: 3 additions & 3 deletions src/nemos/utils.py
@@ -383,9 +383,9 @@ def row_wise_kron(
This function computes the row-wise Kronecker product between dense matrices A and C
using JAX for automatic differentiation and GPU acceleration.

References
----------
1. Petersen, Kaare Brandt, and Michael Syskind Pedersen. "The matrix cookbook."
# References
------------
[1] Petersen, Kaare Brandt, and Michael Syskind Pedersen. "The matrix cookbook."
Technical University of Denmark 7.15 (2008): 510.
"""
if transpose:
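
The row-wise Kronecker product pairs row i of A with row i of C, so each output row is `kron(A[i], C[i])`. A minimal JAX sketch of the operation (illustrative; not nemos' actual implementation):

```python
# Minimal sketch of a row-wise Kronecker product via vmap; illustrative only.
import jax
import jax.numpy as jnp


def row_wise_kron_sketch(A: jnp.ndarray, C: jnp.ndarray) -> jnp.ndarray:
    # Map jnp.kron over matching rows of A and C.
    return jax.vmap(jnp.kron)(A, C)


A = jnp.arange(6.0).reshape(2, 3)
C = jnp.arange(4.0).reshape(2, 2)
K = row_wise_kron_sketch(A, C)
assert K.shape == (2, 6)  # each output row has 3 * 2 = 6 entries
```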
29 changes: 0 additions & 29 deletions tests/test_basis.py
@@ -519,10 +519,6 @@ def test_identifiability_constraint_apply(self):
assert np.allclose(X.mean(axis=0), np.zeros(X.shape[1]))
assert X.shape[1] == bas.n_basis_funcs

def test_conv_args_error(self):
with pytest.raises(ValueError, match="args should only be set"):
self.cls(5, 10, mode="eval")

def test_conv_kwargs_error(self):
with pytest.raises(ValueError, match="kwargs should only be set"):
self.cls(5, mode="eval", test="hi")
@@ -1015,10 +1011,6 @@ def test_identifiability_constraint_apply(self):
assert np.allclose(X.mean(axis=0), np.zeros(X.shape[1]))
assert X.shape[1] == bas.n_basis_funcs - 1

def test_conv_args_error(self):
with pytest.raises(ValueError, match="args should only be set"):
self.cls(5, 10, mode="eval")

def test_conv_kwargs_error(self):
with pytest.raises(ValueError, match="kwargs should only be set"):
self.cls(5, mode="eval", test="hi")
@@ -1509,10 +1501,6 @@ def test_identifiability_constraint_apply(self):
assert np.allclose(X.mean(axis=0), np.zeros(X.shape[1]))
assert X.shape[1] == bas.n_basis_funcs - 1

def test_conv_args_error(self):
with pytest.raises(ValueError, match="args should only be set"):
self.cls(5, 10, mode="eval")

def test_conv_kwargs_error(self):
with pytest.raises(ValueError, match="kwargs should only be set"):
self.cls(5, mode="eval", test="hi")
@@ -2066,15 +2054,10 @@ def test_identifiability_constraint_apply(self):
assert np.allclose(X.mean(axis=0), np.zeros(X.shape[1]))
assert X.shape[1] == bas.n_basis_funcs

def test_conv_args_error(self):
with pytest.raises(ValueError, match="args should only be set"):
self.cls(5, [1, 2, 3, 4, 5], 10, mode="eval")

def test_conv_kwargs_error(self):
with pytest.raises(ValueError, match="kwargs should only be set"):
self.cls(5, decay_rates=[1, 2, 3, 4, 5], mode="eval", test="hi")


def test_transformer_get_params(self):
bas = self.cls(5, decay_rates=[1, 2, 3, 4, 5])
bas_transformer = bas.to_transformer()
@@ -2507,10 +2490,6 @@ def test_identifiability_constraint_apply(self):
assert np.allclose(X.mean(axis=0), np.zeros(X.shape[1]))
assert X.shape[1] == bas.n_basis_funcs - 1

def test_conv_args_error(self):
with pytest.raises(ValueError, match="args should only be set"):
self.cls(5, 10, mode="eval")

def test_conv_kwargs_error(self):
with pytest.raises(ValueError, match="kwargs should only be set"):
self.cls(5, mode="eval", test="hi")
@@ -3039,10 +3018,6 @@ def test_identifiability_constraint_apply(self):
assert np.allclose(X.mean(axis=0), np.zeros(X.shape[1]))
assert X.shape[1] == bas.n_basis_funcs - 1

def test_conv_args_error(self):
with pytest.raises(ValueError, match="args should only be set"):
self.cls(5, 10, mode="eval")

def test_conv_kwargs_error(self):
with pytest.raises(ValueError, match="kwargs should only be set"):
self.cls(5, mode="eval", test="hi")
@@ -3749,12 +3724,8 @@ def test_compute_features_returns_expected_number_of_basis(
)
if eval_basis.shape[1] != basis_a_obj.n_basis_funcs * basis_b_obj.n_basis_funcs:
raise ValueError(
# <<<<<<< HEAD
"Dimensions do not agree: The number of basis should match the first dimension of the "
"fit_transformed basis."
# =======
# "Dimensions do not agree: The number of basis should match the first dimension of the output features."
# >>>>>>> development
f"The number of basis is {n_basis_a * n_basis_b}",
f"The first dimension of the output features is {eval_basis.shape[1]}",
)