Add (Prox-)SVRG as a solver for GLMs #184

Merged: 198 commits, Aug 15, 2024
Changes from 194 commits

Commits
f1de528
Start working on the Prox-SVRG solver
bagibence Jun 21, 2024
ad45bc8
Adapt SVRG and ProxSVRG for use with nemos and allow their use in reg…
bagibence Jun 24, 2024
1c7a9d9
Make ProxSVRG inherit from SVRG and implement early stopping
bagibence Jun 24, 2024
7f5d27b
Speed up (Prox)SVRG with jit, scan, while_loop
bagibence Jun 25, 2024
ef1983d
Remove unnecessary jits
bagibence Jun 25, 2024
3bfa251
Make prox required
bagibence Jun 25, 2024
df751dd
Simplify from scan to fori_loop in update
bagibence Jun 25, 2024
50ea86f
Tiny changes (whitespace and tree_zeros_like)
bagibence Jun 25, 2024
f016f10
Save the initial naive implementation of (Prox)SVRG in solvers_naive
bagibence Jun 25, 2024
56f49a2
Add batch_size and type annotations for (Prox)SVRG
bagibence Jun 25, 2024
47429f9
Rename lr to stepsize in (Prox)SVRG
bagibence Jun 25, 2024
d09fc8d
Move stepsize into state, redifine error, adjust m based on batch_size
bagibence Jun 26, 2024
901adba
Add a (commented out) update method that implements just the inner lo…
bagibence Jun 26, 2024
ac10fa4
Change SVRG.update to update parameters on every data point
bagibence Jun 27, 2024
4b4d3b5
Change epoch_num back to iter_num
bagibence Jun 27, 2024
8905008
Pass args to SVRG.init_state to calculate N, xs, df_xs
bagibence Jun 27, 2024
32dcf6c
Separately define run and update in SVRG and ProxSVRG
bagibence Jun 27, 2024
eb520f1
Fix wrong argument passed to solver.init_state
bagibence Jun 27, 2024
a25e9af
Switch inheritance between SVRG and ProxSVRG
bagibence Jun 27, 2024
4e89a5f
Store the solver in the regularizer
bagibence Jun 27, 2024
372b233
Minor renaming and sanity check
bagibence Jun 27, 2024
34ec663
Remove stale and unreachable code
bagibence Jun 27, 2024
1c483d0
Update solvers to work with mini-batch updates
bagibence Jun 28, 2024
30c7e59
Add x_av to SVRGState and use it in the return value of ProxSVRG.update
bagibence Jun 28, 2024
a48fe7e
Log the loss after each update
bagibence Jun 28, 2024
fa95695
Move the loss logging into the run method
bagibence Jun 28, 2024
1b855e5
Have separate _update_per_batch and _update_per_point methods
bagibence Jul 2, 2024
ccd554e
Add batch_size again and have different update methods
bagibence Jul 8, 2024
811e48b
Use different method for run based on batch size
bagibence Jul 8, 2024
f1618ae
Potentially pre-generate random indices instead of one-by-one
bagibence Jul 10, 2024
f7825cc
Add docstrings and clean up ProxSVRG
bagibence Jul 11, 2024
efc507f
Change how solvers are found: first look in solvers, then in jaxopt
bagibence Jul 11, 2024
e489a15
Switch to new-style RNG keys and remove the uint32 conversion
bagibence Jul 11, 2024
f6b579a
Change default convergence tolerance in (Prox)SVRG
bagibence Jul 12, 2024
3bce956
Add yet another way of defining the error
bagibence Jul 12, 2024
75ec116
Add missing import and change the error calculation to the previous one
bagibence Jul 12, 2024
4b6064f
Make parameters explicit where possible and remove indexing into the …
bagibence Jul 12, 2024
85b366d
added regr testing
BalzaniEdoardo Jul 12, 2024
599c390
merged conflicts
BalzaniEdoardo Jul 12, 2024
7d73a39
improved test data generation
BalzaniEdoardo Jul 12, 2024
212753f
linted conftest.py
BalzaniEdoardo Jul 12, 2024
280221c
added test for tree utils
BalzaniEdoardo Jul 12, 2024
a911fbe
updated description of tree_slice
BalzaniEdoardo Jul 12, 2024
cce74f2
add context manager to test
BalzaniEdoardo Jul 12, 2024
b3c0bf3
remove unused imports
BalzaniEdoardo Jul 12, 2024
41d5d01
uniform srtepsize
BalzaniEdoardo Jul 12, 2024
5759d18
drop unused line
BalzaniEdoardo Jul 12, 2024
1f3d659
FIXED TOL
BalzaniEdoardo Jul 12, 2024
6faa685
Reformat with black
bagibence Jul 15, 2024
66c73ac
Initialize loss_log with zeros instead of empty and run isort
bagibence Jul 15, 2024
ff09dce
Fix flake8's complaints
bagibence Jul 17, 2024
ae99cfc
Add a few tests for SVRG
bagibence Jul 18, 2024
ccbd41f
Add integration tests for (Prox)SVRG
bagibence Jul 19, 2024
b9644bc
Increment iter_num after each xk_update and comment out loss logging
bagibence Jul 19, 2024
549fdfc
Revert the iter_num increment frequency
bagibence Jul 19, 2024
4e66c9e
Test that GLM.fit runs with (Prox)SVRG
bagibence Jul 19, 2024
7ba0a3b
Test that calls to SVRG.update with a naive implementation of the res…
bagibence Jul 19, 2024
0816f13
Remove asserts and update test
bagibence Jul 19, 2024
be9ea0f
Change iter_num incrementation again and update the test for it
bagibence Jul 19, 2024
ee7b6e4
Change iter_num again to mean number of epochs and adjust test
bagibence Jul 19, 2024
bb3a2bf
Copy tree_utils from jaxopt
bagibence Jul 19, 2024
cb626a1
Add test for SVRG._xk_update
bagibence Jul 19, 2024
13f7b43
Use tree_slice of y as well
bagibence Jul 22, 2024
d97bd9b
Update tests
bagibence Jul 22, 2024
12fd6b7
fixed pop glm proximal gradient
BalzaniEdoardo Jul 22, 2024
32f8b02
Switch from X,y to *args in ProxSVRG. Leave prox_lamba explicit
bagibence Jul 23, 2024
6306313
Add ProxSVRG._run to handle prox_lambda better
bagibence Jul 23, 2024
953ef6f
fixing prox_lasso
BalzaniEdoardo Jul 23, 2024
9b3eea3
use our own lasso
BalzaniEdoardo Jul 23, 2024
93961be
added test for prox_lasso
BalzaniEdoardo Jul 23, 2024
5a10d8f
Test SVRG._xk_update with our own prox_lasso instead of jaxopt's
bagibence Jul 23, 2024
2344895
Test that GLM.fit converges to the same as sklearn's PoissonRegressor…
bagibence Jul 23, 2024
a4f588c
Remove prints and spaces
bagibence Jul 23, 2024
7838aab
Test that GLM.update with a naive SVRG outer loop gives the same as G…
bagibence Jul 23, 2024
c4e1a7a
Start working on the Prox-SVRG solver
bagibence Jun 21, 2024
f4dbb0e
Adapt SVRG and ProxSVRG for use with nemos and allow their use in reg…
bagibence Jun 24, 2024
6549266
Make ProxSVRG inherit from SVRG and implement early stopping
bagibence Jun 24, 2024
763be7b
Speed up (Prox)SVRG with jit, scan, while_loop
bagibence Jun 25, 2024
c967fb0
Remove unnecessary jits
bagibence Jun 25, 2024
5bb47cf
Make prox required
bagibence Jun 25, 2024
33d6c7b
Simplify from scan to fori_loop in update
bagibence Jun 25, 2024
4e2393d
Tiny changes (whitespace and tree_zeros_like)
bagibence Jun 25, 2024
25c44b9
Save the initial naive implementation of (Prox)SVRG in solvers_naive
bagibence Jun 25, 2024
0574d93
Add batch_size and type annotations for (Prox)SVRG
bagibence Jun 25, 2024
8278f45
Rename lr to stepsize in (Prox)SVRG
bagibence Jun 25, 2024
d881321
Move stepsize into state, redifine error, adjust m based on batch_size
bagibence Jun 26, 2024
3c2d560
Add a (commented out) update method that implements just the inner lo…
bagibence Jun 26, 2024
4c5bf27
Change SVRG.update to update parameters on every data point
bagibence Jun 27, 2024
245bdaa
Change epoch_num back to iter_num
bagibence Jun 27, 2024
6553737
Pass args to SVRG.init_state to calculate N, xs, df_xs
bagibence Jun 27, 2024
5f629a7
Separately define run and update in SVRG and ProxSVRG
bagibence Jun 27, 2024
6bcdb3f
Switch inheritance between SVRG and ProxSVRG
bagibence Jun 27, 2024
b0d8b93
Minor renaming and sanity check
bagibence Jun 27, 2024
1f12c46
Remove stale and unreachable code
bagibence Jun 27, 2024
e0a5d3f
Update solvers to work with mini-batch updates
bagibence Jun 28, 2024
26481e9
Add x_av to SVRGState and use it in the return value of ProxSVRG.update
bagibence Jun 28, 2024
6850f7e
Log the loss after each update
bagibence Jun 28, 2024
7461caf
Move the loss logging into the run method
bagibence Jun 28, 2024
d9539d4
Have separate _update_per_batch and _update_per_point methods
bagibence Jul 2, 2024
e068b3b
Add batch_size again and have different update methods
bagibence Jul 8, 2024
e26af13
Use different method for run based on batch size
bagibence Jul 8, 2024
ec24269
Potentially pre-generate random indices instead of one-by-one
bagibence Jul 10, 2024
ebfe409
Add docstrings and clean up ProxSVRG
bagibence Jul 11, 2024
b52b71d
Switch to new-style RNG keys and remove the uint32 conversion
bagibence Jul 11, 2024
7c8e7d5
Change default convergence tolerance in (Prox)SVRG
bagibence Jul 12, 2024
a599cd9
Add yet another way of defining the error
bagibence Jul 12, 2024
d94edf5
Add missing import and change the error calculation to the previous one
bagibence Jul 12, 2024
5343a15
added regr testing
BalzaniEdoardo Jul 12, 2024
2044c3f
Make parameters explicit where possible and remove indexing into the …
bagibence Jul 12, 2024
5c600db
improved test data generation
BalzaniEdoardo Jul 12, 2024
cd47403
linted conftest.py
BalzaniEdoardo Jul 12, 2024
a668acf
added test for tree utils
BalzaniEdoardo Jul 12, 2024
00df370
updated description of tree_slice
BalzaniEdoardo Jul 12, 2024
761ab73
add context manager to test
BalzaniEdoardo Jul 12, 2024
f0be599
remove unused imports
BalzaniEdoardo Jul 12, 2024
194b432
uniform srtepsize
BalzaniEdoardo Jul 12, 2024
ecd4ef4
drop unused line
BalzaniEdoardo Jul 12, 2024
e17a495
FIXED TOL
BalzaniEdoardo Jul 12, 2024
d10e507
Initialize loss_log with zeros instead of empty and run isort
bagibence Jul 15, 2024
41b627f
Fix flake8's complaints
bagibence Jul 17, 2024
bc32794
Add a few tests for SVRG
bagibence Jul 18, 2024
846ed43
Add integration tests for (Prox)SVRG
bagibence Jul 19, 2024
8012176
Increment iter_num after each xk_update and comment out loss logging
bagibence Jul 19, 2024
a75f55b
Revert the iter_num increment frequency
bagibence Jul 19, 2024
7fb3d9d
Test that GLM.fit runs with (Prox)SVRG
bagibence Jul 19, 2024
f751c76
Test that calls to SVRG.update with a naive implementation of the res…
bagibence Jul 19, 2024
ad118c7
Remove asserts and update test
bagibence Jul 19, 2024
21273eb
Change iter_num incrementation again and update the test for it
bagibence Jul 19, 2024
4d8fb3f
Change iter_num again to mean number of epochs and adjust test
bagibence Jul 19, 2024
adbe0a2
Copy tree_utils from jaxopt
bagibence Jul 19, 2024
56f1ff3
Add test for SVRG._xk_update
bagibence Jul 19, 2024
8f743df
Use tree_slice of y as well
bagibence Jul 22, 2024
1ca062f
Update tests
bagibence Jul 22, 2024
f51f6d2
fixed pop glm proximal gradient
BalzaniEdoardo Jul 22, 2024
3004c14
Switch from X,y to *args in ProxSVRG. Leave prox_lamba explicit
bagibence Jul 23, 2024
eac538e
Add ProxSVRG._run to handle prox_lambda better
bagibence Jul 23, 2024
0d673cc
fixing prox_lasso
BalzaniEdoardo Jul 23, 2024
457397e
use our own lasso
BalzaniEdoardo Jul 23, 2024
d58cb16
added test for prox_lasso
BalzaniEdoardo Jul 23, 2024
2ae4a13
Test SVRG._xk_update with our own prox_lasso instead of jaxopt's
bagibence Jul 23, 2024
a885cda
Test that GLM.fit converges to the same as sklearn's PoissonRegressor…
bagibence Jul 23, 2024
3de8780
Test that GLM.update with a naive SVRG outer loop gives the same as G…
bagibence Jul 23, 2024
231c0a6
Bit of tidying up in ProxSVRG
bagibence Jul 26, 2024
82fa1a6
Make BaseRegressor aware of nemos.solvers
bagibence Jul 26, 2024
efb97c9
Allow using SVRG and ProxSVRG with regularizers
bagibence Jul 26, 2024
83b093c
Update tests
bagibence Jul 26, 2024
7b0ff57
reconciled history
BalzaniEdoardo Jul 26, 2024
06ed5c7
updated reg
BalzaniEdoardo Jul 26, 2024
cc11af4
fixed error message and test
BalzaniEdoardo Jul 26, 2024
710f484
linting (with partial flake8 on test)
BalzaniEdoardo Jul 26, 2024
9fffdd1
bugfixed mask creation
BalzaniEdoardo Jul 26, 2024
52e9c2e
fixed ellipsis
BalzaniEdoardo Jul 26, 2024
500760a
fixed flake8
BalzaniEdoardo Jul 26, 2024
81a5f0f
removed ellipsis
BalzaniEdoardo Jul 26, 2024
cf56b89
removed ellipsis
BalzaniEdoardo Jul 26, 2024
ff9b9b9
Only store the loss function in BaseRegressor, not the solver
bagibence Jul 26, 2024
2dbe45a
Remove unused solvers_naive.py
bagibence Aug 2, 2024
1f44561
Remove logging the loss in ProxSVRG
bagibence Aug 2, 2024
f327cac
Remove x_av and x_sum in ProxSVRG
bagibence Aug 2, 2024
3f39005
Change default random key from 0 to 123
bagibence Aug 2, 2024
e54726b
Remove unused error calculations and add docstring for the used one i…
bagibence Aug 2, 2024
6926bdc
Apply docstring spelling suggestions from code review
bagibence Aug 2, 2024
cb610dd
Start improving citations
bagibence Aug 5, 2024
4a66ac7
Correct references
bagibence Aug 5, 2024
e8ddf3f
Remove unused imports
bagibence Aug 5, 2024
01a39b4
Remove iteration through args when using `tree_slice`
bagibence Aug 5, 2024
335860a
Remove type hints from the docstrings
bagibence Aug 5, 2024
666afa2
Update description of *args in ProxSVRG methods
bagibence Aug 5, 2024
46653f9
Type hint parameters as Pytree
bagibence Aug 5, 2024
d3a904f
... finishing the previous commit
bagibence Aug 5, 2024
0285b78
Add "Raises" sections in (Prox)SVRG
bagibence Aug 5, 2024
2bad7ac
Correct default random key in SVRG tests
bagibence Aug 5, 2024
e1293f4
Attempt to use more descriptive variable names
bagibence Aug 5, 2024
167c8dd
Fix example formatting in (Prox)SVRG docstrings
bagibence Aug 6, 2024
ca354fe
Make (Prox)SVRG.update docstring more precise and generic
bagibence Aug 6, 2024
41aa456
Remove duplicate comment
bagibence Aug 6, 2024
bdfc35b
Add docstring for SVRGState and update references
bagibence Aug 7, 2024
e5904aa
renamed fileds
BalzaniEdoardo Aug 8, 2024
214c9da
Change variable naming to be more descriptive
bagibence Aug 8, 2024
ce50fbd
Fix references in the docstrings
bagibence Aug 8, 2024
bfc4cfc
Remove init_full_gradient argument from SVRG.init_state
bagibence Aug 8, 2024
46fe498
merged development
BalzaniEdoardo Aug 10, 2024
cd4660b
flake8 fixes
BalzaniEdoardo Aug 10, 2024
61478a4
merged with origin
BalzaniEdoardo Aug 10, 2024
8c74bd9
fixed a deprecation warning
BalzaniEdoardo Aug 10, 2024
5f33bd6
merged dev
BalzaniEdoardo Aug 13, 2024
493f090
removed unused arg
BalzaniEdoardo Aug 13, 2024
3987254
Update src/nemos/solvers.py
BalzaniEdoardo Aug 13, 2024
0c3aabb
removed line
BalzaniEdoardo Aug 13, 2024
ca0b51a
partial solution to references
BalzaniEdoardo Aug 13, 2024
4eb3925
fix references
BalzaniEdoardo Aug 14, 2024
bdb1d3d
linted
BalzaniEdoardo Aug 14, 2024
1537645
merged
BalzaniEdoardo Aug 14, 2024
d39ea73
Merge branch 'development' into svrg
BalzaniEdoardo Aug 14, 2024
f5528bc
Update src/nemos/solvers.py
BalzaniEdoardo Aug 15, 2024
0ead1bf
Merge branch 'development' into svrg
BalzaniEdoardo Aug 15, 2024
a318e33
improved docstrings
BalzaniEdoardo Aug 15, 2024
0d539d2
linted
BalzaniEdoardo Aug 15, 2024
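
For context, the update these commits implement is the standard (Prox-)SVRG step (Johnson & Zhang, 2013; Xiao & Zhang, 2014), which the history above tracks through the anchor point `xs` and its full gradient `df_xs`:

$$
x_{k+1} = \operatorname{prox}_{\eta g}\Big( x_k - \eta \big( \nabla f_{i_k}(x_k) - \nabla f_{i_k}(x_s) + \nabla f(x_s) \big) \Big)
$$

where $f$ is the unpenalized GLM loss, $g$ is the term handled by the proximal operator, $i_k$ indexes a randomly drawn data point or mini-batch, and the anchor $x_s$ together with the full gradient $\nabla f(x_s)$ is refreshed once per epoch. With $g \equiv 0$ the proximal step is the identity and plain SVRG is recovered.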
71 changes: 54 additions & 17 deletions src/nemos/base_regressor.py
@@ -14,7 +14,7 @@
import jaxopt
from numpy.typing import ArrayLike, NDArray

from . import utils, validation
from . import solvers, utils, validation
from ._regularizer_builder import AVAILABLE_REGULARIZERS, create_regularizer
from .base_class import Base
from .regularizer import Regularizer, UnRegularized
@@ -218,18 +218,20 @@ def solver_kwargs(self):
@solver_kwargs.setter
def solver_kwargs(self, solver_kwargs: dict):
"""Setter for the solver_kwargs attribute."""
self._check_solver_kwargs(self.solver_name, solver_kwargs)
self._check_solver_kwargs(
self._get_solver_class(self.solver_name), solver_kwargs
)
self._solver_kwargs = solver_kwargs

@staticmethod
def _check_solver_kwargs(solver_name, solver_kwargs):
def _check_solver_kwargs(solver_class, solver_kwargs):
"""
Check if provided solver keyword arguments are valid.

Parameters
----------
solver_name :
Name of the solver.
solver_class :
Class of the solver.
solver_kwargs :
Additional keyword arguments for the solver.

@@ -238,11 +240,11 @@ def _check_solver_kwargs(solver_name, solver_kwargs):
NameError
If any of the solver keyword arguments are not valid.
"""
solver_args = inspect.getfullargspec(getattr(jaxopt, solver_name)).args
solver_args = inspect.getfullargspec(solver_class).args
undefined_kwargs = set(solver_kwargs.keys()).difference(solver_args)
if undefined_kwargs:
raise NameError(
f"kwargs {undefined_kwargs} in solver_kwargs not a kwarg for jaxopt.{solver_name}!"
f"kwargs {undefined_kwargs} in solver_kwargs not a kwarg for {solver_class.__name__}!"
)

def instantiate_solver(self, *args) -> BaseRegressor:
@@ -253,10 +255,10 @@ def instantiate_solver(self, *args) -> BaseRegressor:
that initialize the solver state, update the model parameters, and run the optimization
as attributes.

This method creates a solver instance from jaxopt library, tailored to the specific loss
function and regularization approach defined by the Regularizer instance. It also handles
the proximal operator if required for the optimization method. The returned functions are
directly usable in optimization loops, simplifying the syntax by pre-setting
This method creates a solver instance from nemos.solvers or the jaxopt library, tailored to
the specific loss function and regularization approach defined by the Regularizer instance.
It also handles the proximal operator if required for the optimization method. The returned
functions are directly usable in optimization loops, simplifying the syntax by pre-setting
common arguments like regularization strength and other hyperparameters.

Parameters
@@ -281,7 +283,7 @@ def instantiate_solver(self, *args) -> BaseRegressor:
# only use penalized loss if not using proximal gradient descent
# In proximal method you must use the unpenalized loss independently
# of what regularizer you are using.
if self.solver_name != "ProximalGradient":
if self.solver_name not in ("ProximalGradient", "ProxSVRG"):
loss = self.regularizer.penalized_loss(
self._predict_and_compute_loss, self.regularizer_strength
)
@@ -295,7 +297,7 @@ def instantiate_solver(self, *args) -> BaseRegressor:
utils.assert_is_callable(loss, "loss")

# some parsing to make sure solver gets instantiated properly
if self.solver_name == "ProximalGradient":
if self.solver_name in ("ProximalGradient", "ProxSVRG"):
if "prox" in self.solver_kwargs:
raise ValueError(
"Proximal operator specification is not permitted. "
@@ -315,7 +317,11 @@ def instantiate_solver(self, *args) -> BaseRegressor:
) = self._inspect_solver_kwargs(solver_kwargs)

# instantiate the solver
solver = getattr(jaxopt, self.solver_name)(fun=loss, **solver_init_kwargs)
solver = self._get_solver_class(self.solver_name)(
fun=loss, **solver_init_kwargs
)

self._solver_loss_fun_ = loss

def solver_run(
init_params: Tuple[DESIGN_INPUT_TYPE, jnp.ndarray], *run_args: jnp.ndarray
@@ -327,10 +333,9 @@ def solver_update(params, state, *run_args, **run_kwargs) -> jaxopt.OptStep:
params, state, *args, *run_args, **solver_update_kwargs, **run_kwargs
)

def solver_init_state(params, state, *run_args, **run_kwargs) -> NamedTuple:
def solver_init_state(params, *run_args, **run_kwargs) -> NamedTuple:
return solver.init_state(
params,
state,
*run_args,
**run_kwargs,
**solver_init_state_kwargs,
@@ -372,7 +377,7 @@ def _inspect_solver_kwargs(

if solver_kwargs:
# instantiate a solver to then inspect the params of its various functions
solver = getattr(jaxopt, self.solver_name)
solver = self._get_solver_class(self.solver_name)

for key, value in solver_kwargs.items():
if key in inspect.getfullargspec(solver.run).args:
@@ -540,3 +545,35 @@ def initialize_state(
) -> Union[Any, NamedTuple]:
"""Initialize the state of the solver for running fit and update."""
pass

@staticmethod
def _get_solver_class(solver_name: str):
"""
Find a solver class first looking in nemos.solvers, then in jaxopt.

Parameters
----------
solver_name : str
Name of the solver class to load.

Returns
-------
solver_class :
Solver class ready to be instantiated.

Raises
------
AttributeError
If a solver class with that name is not found.
"""
try:
solver_class = getattr(solvers, solver_name)
except AttributeError:
try:
solver_class = getattr(jaxopt, solver_name)
except AttributeError:
raise AttributeError(
f"Could not find {solver_name} in nemos.solvers or jaxopt"
)

return solver_class
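
A minimal usage sketch of the lookup above (the names and module paths follow this diff; it assumes `nemos.solvers` exports only the new `SVRG`/`ProxSVRG` classes, so other names still resolve to jaxopt):

```python
import jaxopt

from nemos.base_regressor import BaseRegressor

# "SVRG" is defined in nemos.solvers, so the new implementation wins ...
svrg_cls = BaseRegressor._get_solver_class("SVRG")

# ... while names nemos does not define fall back to jaxopt.
assert BaseRegressor._get_solver_class("GradientDescent") is jaxopt.GradientDescent

# Unknown names raise from the second lookup.
try:
    BaseRegressor._get_solver_class("NotASolver")
except AttributeError as err:
    print(err)  # Could not find NotASolver in nemos.solvers or jaxopt
```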
30 changes: 15 additions & 15 deletions src/nemos/basis.py
@@ -1350,7 +1350,7 @@ def _check_n_basis_min(self) -> None:

class MSplineBasis(SplineBasis):
r"""
M-spline[$^1$](#references) basis functions for modeling and data transformation.
M-spline[$^{[1]}$](#references) basis functions for modeling and data transformation.

M-splines are a type of spline basis function used for smooth curve fitting
and data representation. They are positive and integrate to one, making them
@@ -1394,8 +1394,8 @@ class MSplineBasis(SplineBasis):
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = mspline_basis(sample_points)

References
----------
# References
------------
[1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science,
3(4), 425-441.

@@ -1517,7 +1517,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:

class BSplineBasis(SplineBasis):
"""
B-spline[$^1$](#references) 1-dimensional basis functions.
B-spline[$^{[1]}$](#references) 1-dimensional basis functions.

Parameters
----------
@@ -1546,9 +1546,9 @@ class BSplineBasis(SplineBasis):
Spline order.


References
----------
1. Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques.
# References
------------
[1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques.
Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5

"""
@@ -1779,7 +1779,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
class RaisedCosineBasisLinear(Basis):
"""Represent linearly-spaced raised cosine basis functions.

This implementation is based on the cosine bumps used by Pillow et al.[$^1$](#references)
This implementation is based on the cosine bumps used by Pillow et al.[$^{[1]}$](#references)
to uniformly tile the internal points of the domain.

Parameters
@@ -1801,9 +1801,9 @@ class RaisedCosineBasisLinear(Basis):
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`

References
----------
1. Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
# References
------------
[1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
C. E. (2005). Prediction and decoding of retinal ganglion cell responses
with a probabilistic spiking model. Journal of Neuroscience, 25(47),
11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005
@@ -1964,7 +1964,7 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear):
"""Represent log-spaced raised cosine basis functions.

Similar to `RaisedCosineBasisLinear` but the basis functions are log-spaced.
This implementation is based on the cosine bumps used by Pillow et al.[$^1$](#references)
This implementation is based on the cosine bumps used by Pillow et al.[$^{[1]}$](#references)
to uniformly tile the internal points of the domain.

Parameters
@@ -1994,9 +1994,9 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear):
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`

References
----------
1. Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
# References
------------
[1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
C. E. (2005). Prediction and decoding of retinal ganglion cell responses
with a probabilistic spiking model. Journal of Neuroscience, 25(47),
11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005
35 changes: 18 additions & 17 deletions src/nemos/glm.py
@@ -622,19 +622,6 @@ def fit(
else:
data = X

# check if mask has been set when using group lasso
# if mask has not been set, use a single group as default
if isinstance(self.regularizer, GroupLasso):
if self.regularizer.mask is None:
warnings.warn(
UserWarning(
"Mask has not been set. Defaulting to a single group for all parameters. "
"Please see the documentation on GroupLasso regularization for defining a "
"mask."
)
)
self.regularizer.mask = jnp.ones((1, data.shape[1]))

self.initialize_state(data, y, init_params)

params, state = self.solver_run(init_params, data, y)
@@ -882,13 +869,27 @@ def initialize_state(
NamedTuple
The initialized solver state
"""
# set up the solver init/run/update attrs
self.instantiate_solver()

if isinstance(X, FeaturePytree):
data = X.data
else:
data = X

# check if mask has been set when using group lasso
# if mask has not been set, use a single group as default
if isinstance(self.regularizer, GroupLasso):
if self.regularizer.mask is None:
warnings.warn(
UserWarning(
"Mask has not been set. Defaulting to a single group for all parameters. "
"Please see the documentation on GroupLasso regularization for defining a "
"mask."
)
)
self.regularizer.mask = jnp.ones((1, data.shape[1]))

# set up the solver init/run/update attrs
self.instantiate_solver()

opt_state = self.solver_init_state(init_params, data, y)
return opt_state

@@ -1311,7 +1312,7 @@ def _check_mask(self, X, y, params):
axis_2=1,
err_message="Inconsistent number of neurons. "
f"feature_mask has {jax.tree_util.tree_map(lambda m: m.shape[neural_axis], self.feature_mask)} neurons, "
f"model coefficients have {jax.tree_util.tree_map(lambda x: x.shape[1], X)} instead!",
f"model coefficients have {jax.tree_util.tree_map(lambda x: x.shape[1], params[0])} instead!",
)

@cast_to_jax
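
For illustration, the defaulting behavior this change relocates, restated as a self-contained sketch (`_default_group_lasso_mask` is a hypothetical helper, not part of the PR, and the warning text is abbreviated):

```python
import warnings

import jax.numpy as jnp

def _default_group_lasso_mask(mask, n_features: int):
    # No user-supplied mask: assign every feature to a single group,
    # i.e. a (1, n_features) matrix of ones, and warn about it.
    if mask is None:
        warnings.warn(
            "Mask has not been set. Defaulting to a single group.", UserWarning
        )
        return jnp.ones((1, n_features))
    return mask

mask = _default_group_lasso_mask(None, n_features=5)  # (1, 5) array of ones
```

Because `initialize_state` serves both `fit` and the incremental `update` path, performing the defaulting here (rather than in `fit` only) means both entry points see a valid mask.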
14 changes: 7 additions & 7 deletions src/nemos/observation_models.py
@@ -267,8 +267,8 @@ def pseudo_r2(
) -> jnp.ndarray:
r"""Pseudo-$R^2$ calculation for a GLM.

Compute the pseudo-$R^2$ metric for the GLM, as defined by McFadden et al.[$^1$](#references)
or by Cohen et al.[$^2$](#references).
Compute the pseudo-$R^2$ metric for the GLM, as defined by McFadden et al.[$^{[1]}$](#references)
or by Cohen et al.[$^{[2]}$](#references).

This metric evaluates the goodness-of-fit of the model relative to a null (baseline) model that assumes a
constant mean for the observations. While the pseudo-$R^2$ is bounded between 0 and 1 for the training set,
@@ -311,13 +311,13 @@
sample, i.e. the maximum value that the likelihood could possibly achieve). $D_M$ and $D_0$ are
the model and the null deviance, $D_i = -2 \left[ \log(L_s) - \log(L_i) \right]$ for $i=M,0$.


References
----------
1. McFadden D (1979). Quantitative methods for analysing travel behavior of individuals: Some recent
# References
------------
[1] McFadden D (1979). Quantitative methods for analysing travel behavior of individuals: Some recent
developments. In D. A. Hensher & P. R. Stopher (Eds.), *Behavioural travel modelling* (pp. 279-318).
London: Croom Helm.
2. Jacob Cohen, Patricia Cohen, Steven G. West, Leona S. Aiken.

[2] Jacob Cohen, Patricia Cohen, Steven G. West, Leona S. Aiken.
*Applied Multiple Regression/Correlation Analysis for the Behavioral Sciences*.
3rd edition. Routledge, 2002. p.502. ISBN 978-0-8058-2223-6. (May 2012)
"""
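
For reference, the two definitions the docstring contrasts can be written compactly (standard forms, consistent with the deviance notation above):

$$
R^2_{\text{McFadden}} = 1 - \frac{\log L_M}{\log L_0},
\qquad
R^2_{\text{Cohen}} = \frac{D_0 - D_M}{D_0},
$$

with $D_i = -2\left[\log(L_s) - \log(L_i)\right]$ for $i = M, 0$, where $L_M$, $L_0$, and $L_s$ are the likelihoods of the fitted, null, and saturated models.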
43 changes: 41 additions & 2 deletions src/nemos/proximal_operator.py
@@ -24,7 +24,7 @@
[1] Parikh, Neal, and Stephen Boyd. *"Proximal Algorithms, ser. Foundations and Trends (r) in Optimization."* (2013).
"""

from typing import Tuple
from typing import Any, Optional, Tuple

import jax
import jax.numpy as jnp
@@ -132,6 +132,7 @@ def prox_group_lasso(

"""
weights, intercepts = params
shape = weights.shape
# divide the reg strength by the number of neurons
regularizer_strength /= intercepts.shape[0]
# add an extra dim if not 2D, do nothing otherwise.
@@ -143,4 +144,42 @@
# Avoid shrinkage of features that do not belong to any group
# by setting the shrinkage factor to 1.
not_regularized = jnp.outer(jnp.ones(factor.shape[0]), 1 - mask.sum(axis=0))
return jnp.squeeze(weights * (factor @ mask + not_regularized)).T, intercepts
return (weights * (factor @ mask + not_regularized)).T.reshape(shape), intercepts


def prox_lasso(x: Any, l1reg: Optional[Any] = None, scaling: float = 1.0) -> Any:
r"""Proximal operator for the l1 norm, i.e., soft-thresholding operator.

Minimizes the following function:

$$
\underset{y}{\text{argmin}} ~ \frac{1}{2} ||x - y||\_2^2
+ \text{scaling} \cdot \text{l1reg} \cdot ||y||\_1
$$

When `l1reg` is a pytree, the weights are applied coordinate-wise.

Parameters
----------
x :
Input pytree.
l1reg :
Regularization strength, float or pytree with the same structure as `x`. Default is None.
scaling : float, optional
A scaling factor. Default is 1.0.

Returns
-------
:
Output pytree with the same structure as `x`.
"""
if l1reg is None:
l1reg = 1.0

if jnp.isscalar(l1reg):
l1reg = jax.tree_util.tree_map(lambda y: l1reg * jnp.ones_like(y), x)

def fun(u, v):
return jnp.sign(u) * jax.nn.relu(jnp.abs(u) - v * scaling)

return jax.tree_util.tree_map(fun, x, l1reg)
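
A quick numerical check of the soft-thresholding operator (usage sketch; the import path follows the diff header above):

```python
import jax.numpy as jnp

from nemos.proximal_operator import prox_lasso

x = jnp.array([0.5, -2.0, 0.05])
shrunk = prox_lasso(x, l1reg=0.1)
# Each coordinate is shrunk toward zero by l1reg and clipped at zero:
# sign(u) * relu(|u| - 0.1) -> [0.4, -1.9, 0.0]
print(shrunk)
```

Since the operator is applied leaf-wise with `tree_map`, `x` can equally be a pytree such as the `(weights, intercepts)` tuple used elsewhere in nemos, with `l1reg` given as a matching pytree to, e.g., exempt intercepts from shrinkage.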