diff --git a/src/nemos/base_regressor.py b/src/nemos/base_regressor.py index 0e7581a7..5f651313 100644 --- a/src/nemos/base_regressor.py +++ b/src/nemos/base_regressor.py @@ -14,7 +14,7 @@ import jaxopt from numpy.typing import ArrayLike, NDArray -from . import utils, validation +from . import solvers, utils, validation from ._regularizer_builder import AVAILABLE_REGULARIZERS, create_regularizer from .base_class import Base from .regularizer import Regularizer, UnRegularized @@ -218,18 +218,20 @@ def solver_kwargs(self): @solver_kwargs.setter def solver_kwargs(self, solver_kwargs: dict): """Setter for the solver_kwargs attribute.""" - self._check_solver_kwargs(self.solver_name, solver_kwargs) + self._check_solver_kwargs( + self._get_solver_class(self.solver_name), solver_kwargs + ) self._solver_kwargs = solver_kwargs @staticmethod - def _check_solver_kwargs(solver_name, solver_kwargs): + def _check_solver_kwargs(solver_class, solver_kwargs): """ Check if provided solver keyword arguments are valid. Parameters ---------- - solver_name : - Name of the solver. + solver_class : + Class of the solver. solver_kwargs : Additional keyword arguments for the solver. @@ -238,11 +240,11 @@ def _check_solver_kwargs(solver_name, solver_kwargs): NameError If any of the solver keyword arguments are not valid. """ - solver_args = inspect.getfullargspec(getattr(jaxopt, solver_name)).args + solver_args = inspect.getfullargspec(solver_class).args undefined_kwargs = set(solver_kwargs.keys()).difference(solver_args) if undefined_kwargs: raise NameError( - f"kwargs {undefined_kwargs} in solver_kwargs not a kwarg for jaxopt.{solver_name}!" + f"kwargs {undefined_kwargs} in solver_kwargs not a kwarg for {solver_class.__name__}!" ) def instantiate_solver(self, *args) -> BaseRegressor: @@ -253,10 +255,10 @@ def instantiate_solver(self, *args) -> BaseRegressor: that initialize the solver state, update the model parameters, and run the optimization as attributes. - This method creates a solver instance from jaxopt library, tailored to the specific loss - function and regularization approach defined by the Regularizer instance. It also handles - the proximal operator if required for the optimization method. The returned functions are - directly usable in optimization loops, simplifying the syntax by pre-setting + This method creates a solver instance from nemos.solvers or the jaxopt library, tailored to + the specific loss function and regularization approach defined by the Regularizer instance. + It also handles the proximal operator if required for the optimization method. The returned + functions are directly usable in optimization loops, simplifying the syntax by pre-setting common arguments like regularization strength and other hyperparameters. Parameters @@ -281,7 +283,7 @@ def instantiate_solver(self, *args) -> BaseRegressor: # only use penalized loss if not using proximal gradient descent # In proximal method you must use the unpenalized loss independently # of what regularizer you are using. 
- if self.solver_name != "ProximalGradient": + if self.solver_name not in ("ProximalGradient", "ProxSVRG"): loss = self.regularizer.penalized_loss( self._predict_and_compute_loss, self.regularizer_strength ) @@ -295,7 +297,7 @@ def instantiate_solver(self, *args) -> BaseRegressor: utils.assert_is_callable(loss, "loss") # some parsing to make sure solver gets instantiated properly - if self.solver_name == "ProximalGradient": + if self.solver_name in ("ProximalGradient", "ProxSVRG"): if "prox" in self.solver_kwargs: raise ValueError( "Proximal operator specification is not permitted. " @@ -315,7 +317,11 @@ def instantiate_solver(self, *args) -> BaseRegressor: ) = self._inspect_solver_kwargs(solver_kwargs) # instantiate the solver - solver = getattr(jaxopt, self.solver_name)(fun=loss, **solver_init_kwargs) + solver = self._get_solver_class(self.solver_name)( + fun=loss, **solver_init_kwargs + ) + + self._solver_loss_fun_ = loss def solver_run( init_params: Tuple[DESIGN_INPUT_TYPE, jnp.ndarray], *run_args: jnp.ndarray @@ -327,10 +333,9 @@ def solver_update(params, state, *run_args, **run_kwargs) -> jaxopt.OptStep: params, state, *args, *run_args, **solver_update_kwargs, **run_kwargs ) - def solver_init_state(params, state, *run_args, **run_kwargs) -> NamedTuple: + def solver_init_state(params, *run_args, **run_kwargs) -> NamedTuple: return solver.init_state( params, - state, *run_args, **run_kwargs, **solver_init_state_kwargs, @@ -372,7 +377,7 @@ def _inspect_solver_kwargs( if solver_kwargs: # instantiate a solver to then inspect the params of its various functions - solver = getattr(jaxopt, self.solver_name) + solver = self._get_solver_class(self.solver_name) for key, value in solver_kwargs.items(): if key in inspect.getfullargspec(solver.run).args: @@ -540,3 +545,35 @@ def initialize_state( ) -> Union[Any, NamedTuple]: """Initialize the state of the solver for running fit and update.""" pass + + @staticmethod + def _get_solver_class(solver_name: str): + """ + Find a solver class first looking in nemos.solvers, then in jaxopt. + + Parameters + ---------- + solver_name : str + Name of the solver class to load. + + Returns + ------- + solver_class : + Solver class ready to be instantiated. + + Raises + ------ + AttributeError + If a solver class with that name is not found. + """ + try: + solver_class = getattr(solvers, solver_name) + except AttributeError: + try: + solver_class = getattr(jaxopt, solver_name) + except AttributeError: + raise AttributeError( + f"Could not find {solver_name} in nemos.solvers or jaxopt" + ) + + return solver_class diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 881ea394..ba9d42be 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -1350,7 +1350,7 @@ def _check_n_basis_min(self) -> None: class MSplineBasis(SplineBasis): r""" - M-spline[$^1$](#references) basis functions for modeling and data transformation. + M-spline[$^{[1]}$](#references) basis functions for modeling and data transformation. M-splines are a type of spline basis function used for smooth curve fitting and data representation. They are positive and integrate to one, making them @@ -1394,8 +1394,8 @@ class MSplineBasis(SplineBasis): >>> sample_points = linspace(0, 1, 100) >>> basis_functions = mspline_basis(sample_points) - References - ---------- + # References + ------------ [1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science, 3(4), 425-441. 
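[Editor's note, before continuing with the basis.py hunks: the base_regressor.py changes above route the new "SVRG"/"ProxSVRG" names through `_get_solver_class` and the proximal-gradient branches. A minimal usage sketch of how a model could select them; the `solver_kwargs` keys come from the `SVRG` constructor added later in this patch, and `X`, `y` are placeholder data not defined here.]

    import nemos as nmo

    glm = nmo.glm.GLM(
        regularizer=nmo.regularizer.Ridge(),
        solver_name="SVRG",  # resolved by _get_solver_class: nemos.solvers first, then jaxopt
        solver_kwargs={"stepsize": 1e-3, "batch_size": 32, "tol": 1e-6, "maxiter": 1000},
    )
    # glm.fit(X, y)  # X: (n_samples, n_features) design matrix, y: observed counts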
@@ -1517,7 +1517,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class BSplineBasis(SplineBasis): """ - B-spline[$^1$](#references) 1-dimensional basis functions. + B-spline[$^{[1]}$](#references) 1-dimensional basis functions. Parameters ---------- @@ -1546,9 +1546,9 @@ class BSplineBasis(SplineBasis): Spline order. - References - ---------- - 1. Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. + # References + ------------ + [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5 """ @@ -1779,7 +1779,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class RaisedCosineBasisLinear(Basis): """Represent linearly-spaced raised cosine basis functions. - This implementation is based on the cosine bumps used by Pillow et al.[$^1$](#references) + This implementation is based on the cosine bumps used by Pillow et al.[$^{[1]}$](#references) to uniformly tile the internal points of the domain. Parameters @@ -1801,9 +1801,9 @@ class RaisedCosineBasisLinear(Basis): Only used in "conv" mode. Additional keyword arguments that are passed to `nemos.convolve.create_convolutional_predictor` - References - ---------- - 1. Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + # References + ------------ + [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., C. E. (2005). Prediction and decoding of retinal ganglion cell responses with a probabilistic spiking model. Journal of Neuroscience, 25(47), 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 @@ -1964,7 +1964,7 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear): """Represent log-spaced raised cosine basis functions. Similar to `RaisedCosineBasisLinear` but the basis functions are log-spaced. - This implementation is based on the cosine bumps used by Pillow et al.[$^1$](#references) + This implementation is based on the cosine bumps used by Pillow et al.[$^{[1]}$](#references) to uniformly tile the internal points of the domain. Parameters @@ -1994,9 +1994,9 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear): Only used in "conv" mode. Additional keyword arguments that are passed to `nemos.convolve.create_convolutional_predictor` - References - ---------- - 1. Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + # References + ------------ + [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., C. E. (2005). Prediction and decoding of retinal ganglion cell responses with a probabilistic spiking model. Journal of Neuroscience, 25(47), 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 5c65ef81..8688257f 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -622,19 +622,6 @@ def fit( else: data = X - # check if mask has been set is using group lasso - # if mask has not been set, use a single group as default - if isinstance(self.regularizer, GroupLasso): - if self.regularizer.mask is None: - warnings.warn( - UserWarning( - "Mask has not been set. Defaulting to a single group for all parameters. " - "Please see the documentation on GroupLasso regularization for defining a " - "mask." 
- ) - ) - self.regularizer.mask = jnp.ones((1, data.shape[1])) - self.initialize_state(data, y, init_params) params, state = self.solver_run(init_params, data, y) @@ -882,13 +869,27 @@ def initialize_state( NamedTuple The initialized solver state """ - # set up the solver init/run/update attrs - self.instantiate_solver() - if isinstance(X, FeaturePytree): data = X.data else: data = X + + # check if mask has been set is using group lasso + # if mask has not been set, use a single group as default + if isinstance(self.regularizer, GroupLasso): + if self.regularizer.mask is None: + warnings.warn( + UserWarning( + "Mask has not been set. Defaulting to a single group for all parameters. " + "Please see the documentation on GroupLasso regularization for defining a " + "mask." + ) + ) + self.regularizer.mask = jnp.ones((1, data.shape[1])) + + # set up the solver init/run/update attrs + self.instantiate_solver() + opt_state = self.solver_init_state(init_params, data, y) return opt_state @@ -1311,7 +1312,7 @@ def _check_mask(self, X, y, params): axis_2=1, err_message="Inconsistent number of neurons. " f"feature_mask has {jax.tree_util.tree_map(lambda m: m.shape[neural_axis], self.feature_mask)} neurons, " - f"model coefficients have {jax.tree_util.tree_map(lambda x: x.shape[1], X)} instead!", + f"model coefficients have {jax.tree_util.tree_map(lambda x: x.shape[1], params[0])} instead!", ) @cast_to_jax diff --git a/src/nemos/observation_models.py b/src/nemos/observation_models.py index 18243b09..61ed9381 100644 --- a/src/nemos/observation_models.py +++ b/src/nemos/observation_models.py @@ -267,8 +267,8 @@ def pseudo_r2( ) -> jnp.ndarray: r"""Pseudo-$R^2$ calculation for a GLM. - Compute the pseudo-$R^2$ metric for the GLM, as defined by McFadden et al.[$^1$](#references) - or by Cohen et al.[$^2$](#references). + Compute the pseudo-$R^2$ metric for the GLM, as defined by McFadden et al.[$^{[1]}$](#references) + or by Cohen et al.[$^{[2]}$](#references). This metric evaluates the goodness-of-fit of the model relative to a null (baseline) model that assumes a constant mean for the observations. While the pseudo-$R^2$ is bounded between 0 and 1 for the training set, @@ -311,13 +311,13 @@ def pseudo_r2( sample, i.e. the maximum value that the likelihood could possibly achieve). $D_M$ and $D_0$ are the model and the null deviance, $D_i = -2 \left[ \log(L_s) - \log(L_i) \right]$ for $i=M,0$. - - References - ---------- - 1. McFadden D (1979). Quantitative methods for analysing travel behavior of individuals: Some recent + # References + ------------ + [1] McFadden D (1979). Quantitative methods for analysing travel behavior of individuals: Some recent developments. In D. A. Hensher & P. R. Stopher (Eds.), *Behavioural travel modelling* (pp. 279-318). London: Croom Helm. - 2. Jacob Cohen, Patricia Cohen, Steven G. West, Leona S. Aiken. + + [2] Jacob Cohen, Patricia Cohen, Steven G. West, Leona S. Aiken. *Applied Multiple Regression/Correlation Analysis for the Behavioral Sciences*. 3rd edition. Routledge, 2002. p.502. ISBN 978-0-8058-2223-6. (May 2012) """ diff --git a/src/nemos/proximal_operator.py b/src/nemos/proximal_operator.py index e879f1f2..70a69bfa 100644 --- a/src/nemos/proximal_operator.py +++ b/src/nemos/proximal_operator.py @@ -24,7 +24,7 @@ [1] Parikh, Neal, and Stephen Boyd. *"Proximal Algorithms, ser. Foundations and Trends (r) in Optimization."* (2013). 
""" -from typing import Tuple +from typing import Any, Optional, Tuple import jax import jax.numpy as jnp @@ -132,6 +132,7 @@ def prox_group_lasso( """ weights, intercepts = params + shape = weights.shape # divide the reg strength by the number of neurons regularizer_strength /= intercepts.shape[0] # add an extra dim if not 2D, do nothing otherwise. @@ -143,4 +144,42 @@ def prox_group_lasso( # Avoid shrinkage of features that do not belong to any group # by setting the shrinkage factor to 1. not_regularized = jnp.outer(jnp.ones(factor.shape[0]), 1 - mask.sum(axis=0)) - return jnp.squeeze(weights * (factor @ mask + not_regularized)).T, intercepts + return (weights * (factor @ mask + not_regularized)).T.reshape(shape), intercepts + + +def prox_lasso(x: Any, l1reg: Optional[Any] = None, scaling: float = 1.0) -> Any: + r"""Proximal operator for the l1 norm, i.e., soft-thresholding operator. + + Minimizes the following function: + + $$ + \underset{y}{\text{argmin}} ~ \frac{1}{2} ||x - y||\_2^2 + + \text{scaling} \cdot \text{l1reg} \cdot ||y||\_1 + $$ + + When `l1reg` is a pytree, the weights are applied coordinate-wise. + + Parameters + ---------- + x : + Input pytree. + l1reg : + Regularization strength, float or pytree with the same structure as `x`. Default is None. + scaling : float, optional + A scaling factor. Default is 1.0. + + Returns + ------- + : + Output pytree with the same structure as `x`. + """ + if l1reg is None: + l1reg = 1.0 + + if jnp.isscalar(l1reg): + l1reg = jax.tree_util.tree_map(lambda y: l1reg * jnp.ones_like(y), x) + + def fun(u, v): + return jnp.sign(u) * jax.nn.relu(jnp.abs(u) - v * scaling) + + return jax.tree_util.tree_map(fun, x, l1reg) diff --git a/src/nemos/regularizer.py b/src/nemos/regularizer.py index 7c9268cf..91e59f51 100644 --- a/src/nemos/regularizer.py +++ b/src/nemos/regularizer.py @@ -120,6 +120,8 @@ class are defined in the `allowed_solvers` attribute. "LBFGS", "NonlinearCG", "ProximalGradient", + "SVRG", + "ProxSVRG", ) _default_solver = "GradientDescent" @@ -165,6 +167,8 @@ class Ridge(Regularizer): "LBFGS", "NonlinearCG", "ProximalGradient", + "SVRG", + "ProxSVRG", ) _default_solver = "GradientDescent" @@ -242,7 +246,10 @@ class Lasso(Regularizer): set for L1 regularization (Lasso). It utilizes the `jaxopt` library's proximal gradient optimizer. """ - _allowed_solvers = ("ProximalGradient",) + _allowed_solvers = ( + "ProximalGradient", + "ProxSVRG", + ) _default_solver = "ProximalGradient" @@ -351,7 +358,10 @@ class GroupLasso(Regularizer): >>> print(f"coeff: {model.coef_}") """ - _allowed_solvers = ("ProximalGradient",) + _allowed_solvers = ( + "ProximalGradient", + "ProxSVRG", + ) _default_solver = "ProximalGradient" diff --git a/src/nemos/simulation.py b/src/nemos/simulation.py index 9dc21010..ee80547f 100644 --- a/src/nemos/simulation.py +++ b/src/nemos/simulation.py @@ -60,7 +60,7 @@ def difference_of_gammas( References ---------- - 1. 
[SciPy Docs - "scipy.stats.gamma"](https://docs.scipy.org/doc/ + [1] [SciPy Docs - "scipy.stats.gamma"](https://docs.scipy.org/doc/ scipy/reference/generated/scipy.stats.gamma.html) """ # check that the gamma parameters are positive (scipy returns diff --git a/src/nemos/solvers.py b/src/nemos/solvers.py new file mode 100644 index 00000000..d1b2deeb --- /dev/null +++ b/src/nemos/solvers.py @@ -0,0 +1,770 @@ +from functools import partial +from typing import Callable, NamedTuple, Optional, Union + +import jax +import jax.flatten_util +import jax.numpy as jnp +from jax import grad, jit, lax, random +from jaxopt import OptStep +from jaxopt._src import loop +from jaxopt.prox import prox_none + +from .tree_utils import tree_add_scalar_mul, tree_l2_norm, tree_slice, tree_sub +from .typing import KeyArrayLike, Pytree + + +class SVRGState(NamedTuple): + """ + Optimizer state for (Prox)SVRG. + + Attributes + ---------- + iter_num : + Current epoch or iteration number. + key : + Random key to use when sampling data points or mini-batches. + error : + Scaled difference (~distance) between subsequent parameter values + used to monitor convergence. + stepsize : + Step size of the individual gradient steps. + reference_point : + Anchor/reference/snapshot point where the full gradient is calculated in the SVRG algorithm. + Corresponds to $x_{s}$ in the pseudocode[$^{[1]}$](#references). + full_grad_at_reference_point : + Full gradient at the anchor/reference point. + + # References + ------------ + [1] [Gower, Robert M., Mark Schmidt, Francis Bach, and Peter Richtárik. + "Variance-Reduced Methods for Machine Learning." arXiv preprint arXiv:2010.00892 (2020). + ](https://arxiv.org/abs/2010.00892) + """ + + iter_num: int + key: KeyArrayLike + error: float + stepsize: float + reference_point: Optional[Pytree] = None + full_grad_at_reference_point: Optional[Pytree] = None + + +class ProxSVRG: + """ + Prox-SVRG solver + + Borrowing from jaxopt.ProximalGradient, this solver minimizes: + + objective(params, hyperparams_prox, *args, **kwargs) = + fun(params, *args, **kwargs) + non_smooth(params, hyperparams_prox) + + Attributes + ---------- + fun: Callable + Smooth function of the form ``fun(x, *args, **kwargs)``. + prox: Callable + Proximal operator associated with the function ``non_smooth``. + It should be of the form ``prox(params, hyperparams_prox, scale=1.0)``. + See ``jaxopt.prox`` for examples. + maxiter : int + Maximum number of epochs to run the optimization for. + key : jax.random.PRNGkey + jax PRNGKey to start with. Used for sampling random data points. + stepsize : float + Constant step size to use. + tol: float + Tolerance level for the error when comparing parameters + at the end of consecutive epochs to check for convergence. + batch_size: int + Number of data points to sample per inner loop iteration. + + Examples + -------- + >>> def loss_fn(params, X, y): + >>> ... + >>> + >>> svrg = ProxSVRG(loss_fn, prox_fun) + >>> params, state = svrg.run(init_params, hyperparams_prox, X, y) + + References + ---------- + [1] [Gower, Robert M., Mark Schmidt, Francis Bach, and Peter Richtárik. + "Variance-Reduced Methods for Machine Learning." arXiv preprint arXiv:2010.00892 (2020). + ](https://arxiv.org/abs/2010.00892) + + [2] [Xiao, Lin, and Tong Zhang. + "A proximal stochastic gradient method with progressive variance reduction." + SIAM Journal on Optimization 24.4 (2014): 2057-2075.](https://arxiv.org/abs/1403.4699v1) + + [3] [Johnson, Rie, and Tong Zhang. 
+ "Accelerating stochastic gradient descent using predictive variance reduction." + Advances in neural information processing systems 26 (2013). + ](https://proceedings.neurips.cc/paper/2013/hash/ac1dd209cbcc5e5d1c6e28598e8cbbe8-Abstract.html) + """ + + def __init__( + self, + fun: Callable, + prox: Callable, + maxiter: int = 10_000, + key: Optional[KeyArrayLike] = None, + stepsize: float = 1e-3, + tol: float = 1e-3, + batch_size: int = 1, + ): + self.fun = fun + self.maxiter = maxiter + self.key = key + self.stepsize = stepsize + self.tol = tol + self.loss_gradient = jit(grad(self.fun)) + self.batch_size = batch_size + self.proximal_operator = prox + + def init_state( + self, + init_params: Pytree, + *args, + ) -> SVRGState: + """ + Initialize the solver state + + Parameters + ---------- + init_params : + Pytree containing the initial parameters. + For GLMs it's a tuple of (W, b) + args: + Positional arguments passed to loss function `fun` and its gradient (e.g. `fun(params, *args)`), + most likely input and output data. + They are expected to be Pytrees with arrays or FeaturePytree as their leaves, with all of their + leaves having the same sized first dimension (corresponding to the number of data points). + For GLMs these are: + X : DESIGN_INPUT_TYPE + Input data. + y : jnp.ndarray + Output data. + + Returns + ------- + state : + Initialized optimizer state + """ + state = SVRGState( + iter_num=0, + key=self.key if self.key is not None else random.key(123), + error=jnp.inf, + stepsize=self.stepsize, + reference_point=init_params, + full_grad_at_reference_point=None, + ) + return state + + @partial(jit, static_argnums=(0,)) + def _inner_loop_param_update_step( + self, + params: Pytree, + reference_point: Pytree, + full_grad_at_reference_point: Pytree, + stepsize: float, + hyperparams_prox: Union[float, None], + *args, + ) -> Pytree: + """ + Body of the inner loop of Prox-SVRG that takes a step. + + Parameters + ---------- + params : + Current parameters. + reference_point : + Anchor point. + full_grad_at_reference_point : + Full gradient at the anchor point. + stepsize : + Step size. + hyperparams_prox : + Hyperparameters to `prox`, most commonly regularization strength. + args: + Positional arguments passed to loss function `fun` and its gradient (e.g. `fun(params, *args)`), + most likely input and output data. + They are expected to be Pytrees with arrays or FeaturePytree as their leaves, with all of their + leaves having the same sized first dimension (corresponding to the number of data points). + For GLMs these are: + X : DESIGN_INPUT_TYPE + Input data. + y : jnp.ndarray + Output data. + + Returns + ------- + next_params : + Parameter values after applying the update. + """ + # gradient on batch_{i_k} evaluated at the current parameters + # gradient of f_{i_k} at x_{k} in the pseudocode of Gower et al. 2020 + minibatch_grad_at_current_params = self.loss_gradient(params, *args) + # gradient on batch_{i_k} evaluated at the anchor point + # gradient of f_{i_k} at x_{x} in the pseudocode of Gower et al. 
2020 + minibatch_grad_at_reference_point = self.loss_gradient(reference_point, *args) + + # SVRG gradient estimate + gk = jax.tree_util.tree_map( + lambda a, b, c: a - b + c, + minibatch_grad_at_current_params, + minibatch_grad_at_reference_point, + full_grad_at_reference_point, + ) + + # x_{k+1} = x_{k} - stepsize * g_{k} + next_params = tree_add_scalar_mul(params, -stepsize, gk) + + # apply the proximal operator + next_params = self.proximal_operator( + next_params, hyperparams_prox, scaling=stepsize + ) + + return next_params + + @partial(jit, static_argnums=(0,)) + def update( + self, + params: Pytree, + state: SVRGState, + hyperparams_prox: Union[float, None], + *args, + ) -> OptStep: + """ + Perform a single parameter update on the passed data (no random sampling or loops) + and increment `state.iter_num`. + + Please note that this gets called by `BaseRegressor._solver_update` (e.g., as called by `GLM.update`), + but repeated calls to `(Prox)SVRG.update` (so in turn e.g. to `GLM.update`) on mini-batches passed to it + will not result in running the full (Prox-)SVRG, and parts of the algorithm will have to be implemented outside. + + Parameters + ---------- + params : + Parameters at the end of the previous update, used as the starting point for the current update. + state : + Optimizer state at the end of the previous update. + Needs to have the current anchor point (`reference_point`) and the gradient at the anchor point + (`full_grad_at_reference_point`) already set. + hyperparams_prox : + Hyperparameters to `prox`, most commonly regularization strength. + args: + Positional arguments passed to loss function `fun` and its gradient (e.g. `fun(params, *args)`), + most likely input and output data. + They are expected to be Pytrees with arrays or FeaturePytree as their leaves, with all of their + leaves having the same sized first dimension (corresponding to the number of data points). + For GLMs these are: + X : DESIGN_INPUT_TYPE + Input data. + y : jnp.ndarray + Output data. + + + Returns + ------- + OptStep + reference_point : + Parameters after taking one step defined in the inner loop of Prox-SVRG. + state : + Updated state. + + Raises + ------ + ValueError + The parameter update needs a value for the full gradient at the anchor point, which needs the full data + to be calculated and is expected to be stored in `state.full_grad_at_reference_point`. So if + `state.full_grad_at_reference_point` is None, a ValueError is raised. + """ + if state.full_grad_at_reference_point is None: + raise ValueError( + "Full gradient at the anchor point (state.full_grad_at_reference_point) has to be set." + ) + return self._update_on_batch(params, state, hyperparams_prox, *args) + + @partial(jit, static_argnums=(0,)) + def _update_on_batch( + self, + params: Pytree, + state: SVRGState, + hyperparams_prox: Union[float, None], + *args, + ) -> OptStep: + """ + Update parameters given a mini-batch of data and increment iteration/epoch number in state. + + Note that this method doesn't update `state.reference_point`, `state.full_grad_at_reference_point`, + that has to be done outside. + + Parameters + ---------- + params : + Parameters at the end of the previous update, used as the starting point for the current update. + state : + Optimizer state at the end of the previous update. + Needs to have the current anchor point (`reference_point`) and the gradient at the anchor point + (`full_grad_at_reference_point`) already set. 
+ hyperparams_prox : + Hyperparameters to `prox`, most commonly regularization strength. + args: + Positional arguments passed to loss function `fun` and its gradient (e.g. `fun(params, *args)`), + most likely input and output data. + They are expected to be Pytrees with arrays or FeaturePytree as their leaves, with all of their + leaves having the same sized first dimension (corresponding to the number of data points). + For GLMs these are: + X : DESIGN_INPUT_TYPE + Input data. + y : jnp.ndarray + Output data. + + Returns + ------- + OptStep + reference_point : + Parameters after taking one step defined in the inner loop of Prox-SVRG. + state : + Updated state. + """ + next_params = self._inner_loop_param_update_step( + params, + state.reference_point, + state.full_grad_at_reference_point, + state.stepsize, + hyperparams_prox, + *args, + ) + + state = state._replace( + iter_num=state.iter_num + 1, + ) + + return OptStep(params=next_params, state=state) + + @partial(jit, static_argnums=(0,)) + def run( + self, + init_params: Pytree, + hyperparams_prox: Union[float, None], + *args, + ) -> OptStep: + """ + Run a whole optimization until convergence or until `maxiter` epochs are reached. + Called by `BaseRegressor._solver_run` (e.g. as called by `GLM.fit`) and assumes + that X and y are the full data set. + + Parameters + ---------- + init_params : + Initial parameters to start from. + hyperparams_prox : + Hyperparameters to `prox`, most commonly regularization strength. + args: + Positional arguments passed to loss function `fun` and its gradient (e.g. `fun(params, *args)`), + most likely input and output data. + They are expected to be Pytrees with arrays or FeaturePytree as their leaves, with all of their + leaves having the same sized first dimension (corresponding to the number of data points). + For GLMs these are: + X : DESIGN_INPUT_TYPE + Input data. + y : jnp.ndarray + Output data. + + Returns + ------- + OptStep + final_params : + Parameters at the end of the last innner loop. + (... or the average of the parameters over the last inner loop) + final_state : + Final optimizer state. + """ + # initialize the state, including the full gradient at the initial parameters + init_state = self.init_state( + init_params, + *args, + ) + + return self._run(init_params, init_state, hyperparams_prox, *args) + + @partial(jit, static_argnums=(0,)) + def _run( + self, + init_params: Pytree, + init_state: SVRGState, + hyperparams_prox: Union[float, None], + *args, + ) -> OptStep: + """ + Run a whole optimization until convergence or until `maxiter` epochs are reached. + Called by `BaseRegressor._solver_run` (e.g. as called by `GLM.fit`) and assumes that + X and y are the full data set. + Assumes the state has been initialized, which works a bit differently for SVRG and ProxSVRG. + + Parameters + ---------- + init_params : + Initial parameters to start from. + init_state : + Initialized optimizer state returned by `ProxSVRG.init_state` + hyperparams_prox : + Hyperparameters to `prox`, most commonly regularization strength. + args: + Positional arguments passed to loss function `fun` and its gradient (e.g. `fun(params, *args)`), + most likely input and output data. + They are expected to be Pytrees with arrays or FeaturePytree as their leaves, with all of their + leaves having the same sized first dimension (corresponding to the number of data points). + For GLMs these are: + X : DESIGN_INPUT_TYPE + Input data. + y : jnp.ndarray + Output data. 
+ Returns + ------- + OptStep + final_params : + Parameters at the end of the last innner loop. + (... or the average of the parameters over the last inner loop) + final_state : + Final optimizer state. + """ + + # this method assumes that args hold the full data + def body_fun(step): + prev_reference_point, state = step + + # evaluate and store the full gradient with the params from the last inner loop + state = state._replace( + full_grad_at_reference_point=self.loss_gradient( + prev_reference_point, *args + ) + ) + + # run an update over the whole data + params, state = self._update_per_random_samples( + prev_reference_point, state, hyperparams_prox, *args + ) + + # update reference point (x_{s}) with the final parameters (x_{m}) or an average over + # the inner loop's iterations + # note that the average is currently not implemented + reference_point = params + + state = state._replace( + reference_point=reference_point, + error=self._error( + reference_point, prev_reference_point, state.stepsize + ), + ) + + return OptStep(params=reference_point, state=state) + + # at the end of each epoch, check for convergence or reaching the max number of epochs + def cond_fun(step): + _, state = step + return (state.iter_num <= self.maxiter) & (state.error >= self.tol) + + # initialize the full gradient at the anchor point + # the anchor point is init_params at first + init_state = init_state._replace( + full_grad_at_reference_point=self.loss_gradient(init_params, *args) + ) + + final_params, final_state = loop.while_loop( + cond_fun=cond_fun, + body_fun=body_fun, + init_val=OptStep(params=init_params, state=init_state), + maxiter=self.maxiter, + jit=True, + ) + return OptStep(params=final_params, state=final_state) + + @partial(jit, static_argnums=(0,)) + def _update_per_random_samples( + self, + params: Pytree, + state: SVRGState, + hyperparams_prox: Union[float, None], + *args, + ) -> OptStep: + """ + Performs the inner loop of Prox-SVRG sweeping through approximately one full epoch, + updating the parameters after sampling a mini-batch on each iteration. + + Parameters + ---------- + params : + Parameters at the end of the previous update, used as the starting point for the current update. + state : + Optimizer state at the end of the previous sweep. + Needs to have the current anchor point (`reference_point`) and the gradient at the anchor point + (`full_grad_at_reference_point`) already set. + hyperparams_prox : + Hyperparameters to `prox`, most commonly regularization strength. Can be None. + args : + Positional arguments passed to loss function `fun` and its gradient (e.g. `fun(params, *args)`), + most likely input and output data. + They are expected to be Pytrees with arrays or FeaturePytree as their leaves, with all of their + leaves having the same sized first dimension (corresponding to the number of data points). + For GLMs these are: + X : DESIGN_INPUT_TYPE + Input data. + y : jnp.ndarray + Output data. + + Returns + ------- + OptStep + next_params : + Parameters at the end of the last inner loop. + (... or the average of the parameters over the last inner loop) + state : + Updated state. + + Raises + ------ + ValueError + If not all arguments in args have the same sized first dimension. 
+ """ + n_points_per_arg = {leaf.shape[0] for leaf in jax.tree.leaves(args)} + if not len(n_points_per_arg) == 1: + raise ValueError("All arguments must have the same sized first dimension.") + N = n_points_per_arg.pop() + + m = (N + self.batch_size - 1) // self.batch_size # number of iterations + + def inner_loop_body(_, carry): + params, key = carry + + # sample mini-batch or data point + key, subkey = random.split(key) + ind = random.randint(subkey, (self.batch_size,), 0, N) + + # perform a single update on the mini-batch or data point + next_params = self._inner_loop_param_update_step( + params, + state.reference_point, + state.full_grad_at_reference_point, + state.stepsize, + hyperparams_prox, + *tree_slice(args, ind), + ) + + return (next_params, key) + + next_params, key = lax.fori_loop( + 0, + m, + inner_loop_body, + (params, state.key), + ) + + # update the state + # storing the average over the inner loop to potentially use it in the run loop + state = state._replace( + iter_num=state.iter_num + 1, + key=key, + ) + + return OptStep(params=next_params, state=state) + + @staticmethod + def _error(x, x_prev, stepsize): + """ + Calculate the magnitude of the update relative to the parameters. + Used for terminating the algorithm if a certain tolerance is reached. + + Params + ------ + x : + Parameter values after the update. + x_prev : + Previous parameter values. + + Returns + ------- + Scaled update magnitude. + """ + # stepsize is an argument to be consistent with jaxopt + return tree_l2_norm(tree_sub(x, x_prev)) / tree_l2_norm(x_prev) + + +class SVRG(ProxSVRG): + """ + SVRG solver + + Equivalent to ProxSVRG with prox as the identity function and hyperparams_prox=None. + + Attributes + ---------- + fun: Callable + smooth function of the form ``fun(x, *args, **kwargs)``. + maxiter : int + Maximum number of epochs to run the optimization for. + key : jax.random.PRNGkey + jax PRNGKey to start with. Used for sampling random data points. + stepsize : float + Constant step size to use. + tol: float + Tolerance level for the error when comparing parameters + at the end of consecutive epochs to check for convergence. + batch_size: int + Number of data points to sample per inner loop iteration. + + Examples + -------- + >>> def loss_fn(params, X, y): + >>> ... + >>> + >>> svrg = SVRG(loss_fn) + >>> params, state = svrg.run(init_params, X, y) + + References + ---------- + [1] [Gower, Robert M., Mark Schmidt, Francis Bach, and Peter Richtárik. + "Variance-Reduced Methods for Machine Learning." arXiv preprint arXiv:2010.00892 (2020). + ](https://arxiv.org/abs/2010.00892) + + [2] [Xiao, Lin, and Tong Zhang. "A proximal stochastic gradient method with progressive variance reduction." + SIAM Journal on Optimization 24.4 (2014): 2057-2075.](https://arxiv.org/abs/1403.4699v1) + + [3] [Johnson, Rie, and Tong Zhang. "Accelerating stochastic gradient descent using predictive variance reduction." + Advances in neural information processing systems 26 (2013). 
+ ](https://proceedings.neurips.cc/paper/2013/hash/ac1dd209cbcc5e5d1c6e28598e8cbbe8-Abstract.html) + """ + + def __init__( + self, + fun: Callable, + maxiter: int = 10_000, + key: Optional[KeyArrayLike] = None, + stepsize: float = 1e-3, + tol: float = 1e-3, + batch_size: int = 1, + ): + super().__init__( + fun, + prox_none, + maxiter, + key, + stepsize, + tol, + batch_size, + ) + + def init_state(self, init_params: Pytree, *args, **kwargs) -> SVRGState: + """ + Initialize the solver state + + Parameters + ---------- + init_params : + pytree containing the initial parameters. + args: + Positional arguments passed to loss function `fun` and its gradient (e.g. `fun(params, *args)`), + most likely input and output data. + They are expected to be Pytrees with arrays or FeaturePytree as their leaves, with all of their + leaves having the same sized first dimension (corresponding to the number of data points). + For GLMs these are: + X : DESIGN_INPUT_TYPE + Input data. + y : jnp.ndarray + Output data. + + Returns + ------- + state : + Initialized optimizer state + """ + return super().init_state(init_params, *args, **kwargs) + + @partial(jit, static_argnums=(0,)) + def update(self, params: Pytree, state: SVRGState, *args, **kwargs) -> OptStep: + """ + Perform a single parameter update on the passed data (no random sampling or loops) + and increment `state.iter_num`. + + Please note that this gets called by `BaseRegressor._solver_update` (e.g., as called by `GLM.update`), + but repeated calls to `(Prox)SVRG.update` (so in turn e.g. to `GLM.update`) on mini-batches passed to it + will not result in running the full (Prox-)SVRG, and parts of the algorithm will have to be implemented outside. + + Parameters + ---------- + params : + Parameters at the end of the previous update, used as the starting point for the current update. + state : + Optimizer state at the end of the previous update. + Needs to have the current anchor point (`reference_point`) and the gradient at the anchor point + (`full_grad_at_reference_point`) already set. + args: + Positional arguments passed to loss function `fun` and its gradient (e.g. `fun(params, *args)`), + most likely input and output data. + They are expected to be Pytrees with arrays or FeaturePytree as their leaves, with all of their + leaves having the same sized first dimension (corresponding to the number of data points). + For GLMs these are: + X : DESIGN_INPUT_TYPE + Input data. + y : jnp.ndarray + Output data. + + Returns + ------- + OptStep + reference_point : + Parameters after taking one step defined in the inner loop of Prox-SVRG. + state : + Updated state. + + Raises + ------ + ValueError + The parameter update needs a value for the full gradient at the anchor point, which needs the full data + to be calculated and is expected to be stored in `state.full_grad_at_reference_point`. + So if `state.full_grad_at_reference_point` is None, a ValueError is raised. + """ + # substitute None for hyperparams_prox + return super().update(params, state, None, *args, **kwargs) + + @partial(jit, static_argnums=(0,)) + def run( + self, + init_params: Pytree, + *args, + ) -> OptStep: + """ + Run a whole optimization until convergence or until `maxiter` epochs are reached. + Called by `BaseRegressor._solver_run` (e.g. as called by `GLM.fit`) and assumes that + X and y are the full data set. + + Parameters + ---------- + init_params : + Initial parameters to start from. + args: + Positional arguments passed to loss function `fun` and its gradient (e.g. 
`fun(params, *args)`), + most likely input and output data. + They are expected to be Pytrees with arrays or FeaturePytree as their leaves, with all of their + leaves having the same sized first dimension (corresponding to the number of data points). + For GLMs these are: + X : DESIGN_INPUT_TYPE + Input data. + y : jnp.ndarray + Output data. + + Returns + ------- + OptStep + final_params : + Parameters at the end of the last innner loop. + (... or the average of the parameters over the last inner loop) + final_state : + Final optimizer state. + """ + # initialize the state, including the full gradient at the initial parameters + # don't have to pass hyperparams_prox here + init_state = self.init_state(init_params, *args) + + # substitute None for hyperparams_prox + return self._run(init_params, init_state, None, *args) diff --git a/src/nemos/tree_utils.py b/src/nemos/tree_utils.py index bd2b7267..dc8ced29 100644 --- a/src/nemos/tree_utils.py +++ b/src/nemos/tree_utils.py @@ -1,6 +1,7 @@ """Utilities for manipulating and checking PyTrees.""" -from functools import reduce +import operator +from functools import partial, reduce from typing import Any, Callable, Optional import jax @@ -140,3 +141,65 @@ def pytree_map_and_reduce( cond_tree = jax.tree_util.tree_map(map_fn, *pytrees, is_leaf=is_leaf) # for some reason, tree_reduce doesn't work well with any. return reduce_fn(jax.tree_util.tree_leaves(cond_tree)) + + +def tree_slice(data: Any, idx): + """ + Apply an indexing operation to each array in a nested structure. + + Parameters + ---------- + data : + A nested structure containing arrays (e.g., a dictionary of arrays). + idx : + The indexing operation to apply. This can be an integer, slice, + NumPy array (boolean or integer), tuple of indexing operations, ellipsis, or None. + + Returns + ------- + Any + A nested structure with the same format as `data`, where each array has been sliced according to `idx`. + """ + return jax.tree_util.tree_map(lambda x: x[idx], data) + + +# The following functions are adapted from jaxopt.tree_utils + +tree_add = partial(jax.tree_util.tree_map, operator.add) +tree_add.__doc__ = "Tree addition." + +tree_sub = partial(jax.tree_util.tree_map, operator.sub) +tree_sub.__doc__ = "Tree subtraction." 
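[Editor's note: the pytree helpers added in this hunk (together with `tree_add_scalar_mul` and `tree_l2_norm` defined just below) are the arithmetic that (Prox)SVRG builds on. A small self-contained sketch, with toy values, of how they compose into the parameter step and the relative-error check used by `_error`.]

    import jax.numpy as jnp
    from nemos.tree_utils import tree_add_scalar_mul, tree_l2_norm, tree_sub

    params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array([0.5])}
    grads = {"w": jnp.array([0.1, 0.4]), "b": jnp.array([-0.2])}

    # x_{k+1} = x_k - stepsize * g_k, applied leaf-wise over the pytree
    next_params = tree_add_scalar_mul(params, -1e-3, grads)
    # convergence criterion used by (Prox)SVRG._error
    rel_change = tree_l2_norm(tree_sub(next_params, params)) / tree_l2_norm(params)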
+ + +def tree_scalar_mul(scalar, tree_x): + """Compute scalar * tree_x.""" + return jax.tree_util.tree_map(lambda x: scalar * x, tree_x) + + +def tree_add_scalar_mul(tree_x, scalar, tree_y): + """Compute tree_x + scalar * tree_y.""" + return jax.tree_util.tree_map(lambda x, y: x + scalar * y, tree_x, tree_y) + + +def tree_sum(tree_x): + """Compute sum(tree_x).""" + sums = jax.tree_util.tree_map(jnp.sum, tree_x) + return jax.tree_util.tree_reduce(operator.add, sums) + + +def tree_l2_norm(tree_x, squared=False): + """Compute the l2 norm ||tree_x||.""" + squared_tree = jax.tree_util.tree_map( + lambda leaf: jnp.square(leaf.real) + jnp.square(leaf.imag), tree_x + ) + sqnorm = tree_sum(squared_tree) + if squared: + return sqnorm + else: + return jnp.sqrt(sqnorm) + + +def tree_zeros_like(tree_x): + """Creates an all-zero tree with the same structure as tree_x.""" + return jax.tree_util.tree_map(jnp.zeros_like, tree_x) diff --git a/src/nemos/typing.py b/src/nemos/typing.py index 792655a0..dd9bc5a6 100644 --- a/src/nemos/typing.py +++ b/src/nemos/typing.py @@ -4,11 +4,17 @@ import jax.numpy as jnp import jaxopt +from jax._src.typing import ArrayLike from .pytrees import FeaturePytree DESIGN_INPUT_TYPE = Union[jnp.ndarray, FeaturePytree] +Pytree = Any + +# copying jax.random's annotation +KeyArrayLike = ArrayLike + SolverRun = Callable[ [ Any, # parameters, could be any pytree diff --git a/src/nemos/utils.py b/src/nemos/utils.py index 6fa8f872..87ad0472 100644 --- a/src/nemos/utils.py +++ b/src/nemos/utils.py @@ -383,9 +383,9 @@ def row_wise_kron( This function computes the row-wise Kronecker product between dense matrices A and C using JAX for automatic differentiation and GPU acceleration. - References - ---------- - 1. Petersen, Kaare Brandt, and Michael Syskind Pedersen. "The matrix cookbook." + # References + ------------ + [1] Petersen, Kaare Brandt, and Michael Syskind Pedersen. "The matrix cookbook." Technical University of Denmark 7.15 (2008): 510. 
""" if transpose: diff --git a/tests/conftest.py b/tests/conftest.py index 6693bd58..08815bf0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -542,3 +542,83 @@ def gamma_population_GLM_model_pytree(gamma_population_GLM_model): observation_model=model.observation_model, regularizer=model.regularizer ) return X_tree, spikes, model_tree, true_params_tree, rate + + +@pytest.fixture +def regr_data(): + np.random.seed(123) + # define inputs and coeff + n_samples, n_features = 50, 3 + X = np.random.normal(size=(n_samples, n_features)) + coef = np.random.normal(size=(n_features)) + # set y according to lin reg eqn + y = X.dot(coef) + 0.1 * np.random.normal(size=(n_samples,)) + return X, y, coef + + +@pytest.fixture +def linear_regression(regr_data): + X, y, coef = regr_data + # solve least-squares + ols, _, _, _ = np.linalg.lstsq(X, y, rcond=-1) + + # set the loss + def loss(params, X, y): + return jnp.power(y - jnp.dot(X, params), 2).mean() + + return X, y, coef, ols, loss + + +@pytest.fixture +def ridge_regression(regr_data): + X, y, coef = regr_data + + # solve least-squares + yagu = np.hstack((y, np.zeros_like(coef))) + Xagu = np.vstack((X, np.sqrt(0.5) * np.eye(coef.shape[0]))) + ridge, _, _, _ = np.linalg.lstsq(Xagu, yagu, rcond=-1) + + # set the loss + def loss(params, XX, yy): + return ( + jnp.power(yy - jnp.dot(XX, params), 2).sum() + + 0.5 * jnp.power(params, 2).sum() + ) + + return X, y, coef, ridge, loss + + +@pytest.fixture +def linear_regression_tree(linear_regression): + X, y, coef, ols, loss = linear_regression + X_tree = dict(input_1=X[..., :2], input_2=X[..., 2:]) + coef_tree = dict(input_1=coef[:2], input_2=coef[2:]) + ols_tree = dict(input_1=ols[:2], input_2=ols[2:]) + + nmo.tree_utils.pytree_map_and_reduce(jnp.dot, sum, X_tree, coef_tree) + + def loss_tree(params, XX, yy): + pred = nmo.tree_utils.pytree_map_and_reduce(jnp.dot, sum, XX, params) + return jnp.power(yy - pred, 2).sum() + + return X_tree, y, coef_tree, ols_tree, loss_tree + + +@pytest.fixture() +def ridge_regression_tree(ridge_regression): + X, y, coef, ridge, loss = ridge_regression + X_tree = dict(input_1=X[..., :2], input_2=X[..., 2:]) + coef_tree = dict(input_1=coef[:2], input_2=coef[2:]) + ridge_tree = dict(input_1=ridge[:2], input_2=ridge[2:]) + + def loss_tree(params, XX, yy): + pred = nmo.tree_utils.pytree_map_and_reduce(jnp.dot, sum, XX, params) + norm = ( + 0.5 + * nmo.tree_utils.pytree_map_and_reduce( + lambda x: jnp.power(x, 2).sum(), sum, params + ).sum() + ) + return jnp.power(yy - pred, 2).sum() + norm + + return X_tree, y, coef_tree, ridge_tree, loss_tree diff --git a/tests/test_base_class.py b/tests/test_base_class.py index d6ddff02..4af81c30 100644 --- a/tests/test_base_class.py +++ b/tests/test_base_class.py @@ -76,7 +76,7 @@ def test_abstract_class(): def test_invalid_concrete_class(): """Ensure that classes missing implementation of required abstract methods raise errors.""" with pytest.raises(TypeError, match="Can't instantiate abstract"): - model = MockBaseRegressorInvalid() + MockBaseRegressorInvalid() def test_empty_set(mock_regressor): diff --git a/tests/test_basis.py b/tests/test_basis.py index 6f6741fc..6e81d142 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -8,7 +8,6 @@ import pynapple as nap import pytest import sklearn.pipeline as pipeline -import statsmodels.api as sm import utils_testing from sklearn.base import clone as sk_clone from sklearn.model_selection import GridSearchCV @@ -476,23 +475,6 @@ def test_init_mode(self, mode, expectation): with 
expectation: self.cls(5, mode=mode, window_size=window_size) - @pytest.mark.parametrize( - "mode, ws, expectation", - [ - ("eval", None, does_not_raise()), - ("conv", 2, does_not_raise()), - ("eval", 2, does_not_raise()), - ( - "conv", - None, - pytest.raises(ValueError, match="If the basis is in `conv`"), - ), - ], - ) - def test_init_window_size(self, mode, ws, expectation): - with expectation: - self.cls(5, mode=mode, window_size=ws) - @pytest.mark.parametrize( "mode, ws, expectation", [ @@ -539,7 +521,7 @@ def test_identifiability_constraint_apply(self): def test_conv_kwargs_error(self): with pytest.raises(ValueError, match="kwargs should only be set"): - bas = self.cls(5, mode="eval", test="hi") + self.cls(5, mode="eval", test="hi") @pytest.mark.parametrize( @@ -841,7 +823,7 @@ def test_pynapple_support_compute_features(self, n_basis, sample_size): assert isinstance(out, nap.TsdFrame) assert np.all(out.time_support.values == inp.time_support.values) - ## TEST CALL + # TEST CALL @pytest.mark.parametrize( "num_input, expectation", [ @@ -986,23 +968,6 @@ def test_init_mode(self, mode, expectation): with expectation: self.cls(5, mode=mode, window_size=window_size) - @pytest.mark.parametrize( - "mode, ws, expectation", - [ - ("eval", None, does_not_raise()), - ("conv", 2, does_not_raise()), - ("eval", 2, does_not_raise()), - ( - "conv", - None, - pytest.raises(ValueError, match="If the basis is in `conv`"), - ), - ], - ) - def test_init_window_size(self, mode, ws, expectation): - with expectation: - self.cls(5, mode=mode, window_size=ws) - @pytest.mark.parametrize( "mode, ws, expectation", [ @@ -1048,7 +1013,7 @@ def test_identifiability_constraint_apply(self): def test_conv_kwargs_error(self): with pytest.raises(ValueError, match="kwargs should only be set"): - bas = self.cls(5, mode="eval", test="hi") + self.cls(5, mode="eval", test="hi") @pytest.mark.parametrize( "bounds, expectation", @@ -1353,7 +1318,7 @@ def test_pynapple_support_compute_features(self, n_basis, sample_size): assert isinstance(out, nap.TsdFrame) assert np.all(out.time_support.values == inp.time_support.values) - ## TEST CALL + # TEST CALL @pytest.mark.parametrize( "num_input, expectation", [ @@ -1493,23 +1458,6 @@ def test_init_mode(self, mode, expectation): with expectation: self.cls(5, mode=mode, window_size=window_size) - @pytest.mark.parametrize( - "mode, ws, expectation", - [ - ("eval", None, does_not_raise()), - ("conv", 2, does_not_raise()), - ("eval", 2, does_not_raise()), - ( - "conv", - None, - pytest.raises(ValueError, match="If the basis is in `conv`"), - ), - ], - ) - def test_init_window_size(self, mode, ws, expectation): - with expectation: - self.cls(5, mode=mode, window_size=ws) - @pytest.mark.parametrize( "mode, ws, expectation", [ @@ -1553,10 +1501,9 @@ def test_identifiability_constraint_apply(self): assert np.allclose(X.mean(axis=0), np.zeros(X.shape[1])) assert X.shape[1] == bas.n_basis_funcs - 1 - def test_conv_kwargs_error(self): with pytest.raises(ValueError, match="kwargs should only be set"): - bas = self.cls(5, mode="eval", test="hi") + self.cls(5, mode="eval", test="hi") @pytest.mark.parametrize( @@ -1900,7 +1847,7 @@ def test_pynapple_support_compute_features(self, n_basis, sample_size): assert isinstance(out, nap.TsdFrame) assert np.all(out.time_support.values == inp.time_support.values) - ## TEST CALL + # TEST CALL @pytest.mark.parametrize( "num_input, expectation", [ @@ -2063,23 +2010,6 @@ def test_init_mode(self, mode, expectation): with expectation: self.cls(5, mode=mode, 
window_size=window_size, decay_rates=np.arange(1, 6)) - @pytest.mark.parametrize( - "mode, ws, expectation", - [ - ("eval", None, does_not_raise()), - ("conv", 10, does_not_raise()), - ("eval", 2, does_not_raise()), - ( - "conv", - None, - pytest.raises(ValueError, match="If the basis is in `conv`"), - ), - ], - ) - def test_init_window_size(self, mode, ws, expectation): - with expectation: - self.cls(5, mode=mode, window_size=ws, decay_rates=np.arange(1, 6)) - @pytest.mark.parametrize( "mode, ws, expectation", [ @@ -2315,7 +2245,7 @@ def test_evaluate_on_grid_meshgrid_size(self, sample_size): with pytest.raises( ValueError, match=r"Invalid input data|" - rf"All sample counts provided must be greater", + r"All sample counts provided must be greater", ): basis_obj.evaluate_on_grid(sample_size) else: @@ -2517,23 +2447,6 @@ def test_init_mode(self, mode, expectation): with expectation: self.cls(5, mode=mode, window_size=window_size) - @pytest.mark.parametrize( - "mode, ws, expectation", - [ - ("eval", None, does_not_raise()), - ("conv", 2, does_not_raise()), - ("eval", 2, does_not_raise()), - ( - "conv", - None, - pytest.raises(ValueError, match="If the basis is in `conv`"), - ), - ], - ) - def test_init_window_size(self, mode, ws, expectation): - with expectation: - self.cls(5, mode=mode, window_size=ws) - @pytest.mark.parametrize( "mode, ws, expectation", [ @@ -2579,7 +2492,7 @@ def test_identifiability_constraint_apply(self): def test_conv_kwargs_error(self): with pytest.raises(ValueError, match="kwargs should only be set"): - bas = self.cls(5, mode="eval", test="hi") + self.cls(5, mode="eval", test="hi") @pytest.mark.parametrize( @@ -3062,41 +2975,6 @@ def test_transform_fails(self): ): bas._compute_features(np.linspace(0, 1, 10)) - @pytest.mark.parametrize( - "mode, expectation", - [ - ("eval", does_not_raise()), - ("conv", does_not_raise()), - ( - "invalid", - pytest.raises( - ValueError, match="`mode` should be either 'conv' or 'eval'" - ), - ), - ], - ) - def test_init_mode(self, mode, expectation): - window_size = None if mode == "eval" else 2 - with expectation: - self.cls(5, mode=mode, window_size=window_size) - - @pytest.mark.parametrize( - "mode, ws, expectation", - [ - ("eval", None, does_not_raise()), - ("conv", 2, does_not_raise()), - ("eval", 2, does_not_raise()), - ( - "conv", - None, - pytest.raises(ValueError, match="If the basis is in `conv`"), - ), - ], - ) - def test_init_window_size(self, mode, ws, expectation): - with expectation: - self.cls(5, mode=mode, window_size=ws) - @pytest.mark.parametrize( "mode, ws, expectation", [ @@ -3140,10 +3018,9 @@ def test_identifiability_constraint_apply(self): assert np.allclose(X.mean(axis=0), np.zeros(X.shape[1])) assert X.shape[1] == bas.n_basis_funcs - 1 - def test_conv_kwargs_error(self): with pytest.raises(ValueError, match="kwargs should only be set"): - bas = self.cls(5, mode="eval", test="hi") + self.cls(5, mode="eval", test="hi") @pytest.mark.parametrize( @@ -3847,7 +3724,8 @@ def test_compute_features_returns_expected_number_of_basis( ) if eval_basis.shape[1] != basis_a_obj.n_basis_funcs * basis_b_obj.n_basis_funcs: raise ValueError( - "Dimensions do not agree: The number of basis should match the first dimension of the output features." + "Dimensions do not agree: The number of basis should match the first dimension of the " + "fit_transformed basis." 
f"The number of basis is {n_basis_a * n_basis_b}", f"The first dimension of the output features is {eval_basis.shape[1]}", ) @@ -3862,7 +3740,8 @@ def test_sample_size_of_compute_features_matches_that_of_input( self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, mode, window_size ): """ - Test whether the output sample size from the `MultiplicativeBasis` compute_features function matches the input sample size. + Test whether the output sample size from the `MultiplicativeBasis` fit_transform function + matches the input sample size. """ basis_a_obj = self.instantiate_basis( n_basis_a, basis_a, mode=mode, window_size=window_size diff --git a/tests/test_convergence.py b/tests/test_convergence.py index 226caacd..881ee0ac 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -3,12 +3,16 @@ import jax import numpy as np +import pytest from scipy.optimize import minimize import nemos as nmo -def test_unregularized_convergence(): +@pytest.mark.parametrize( + "solver_names", [("GradientDescent", "ProximalGradient"), ("SVRG", "ProxSVRG")] +) +def test_unregularized_convergence(solver_names): """ Assert that solution found when using GradientDescent vs ProximalGradient with an unregularized GLM is the same. @@ -46,7 +50,10 @@ def test_unregularized_convergence(): assert np.allclose(model_GD.intercept_, model_PG.intercept_) -def test_ridge_convergence(): +@pytest.mark.parametrize( + "solver_names", [("GradientDescent", "ProximalGradient"), ("SVRG", "ProxSVRG")] +) +def test_ridge_convergence(solver_names): """ Assert that solution found when using GradientDescent vs ProximalGradient with an ridge GLM is the same. @@ -85,7 +92,8 @@ def test_ridge_convergence(): assert np.allclose(model_GD.intercept_, model_PG.intercept_) -def test_lasso_convergence(): +@pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) +def test_lasso_convergence(solver_name): """ Assert that solution found when using ProximalGradient versus Nelder-Mead method using lasso GLM is the same. @@ -128,7 +136,8 @@ def test_lasso_convergence(): assert np.allclose(res.x[:1], model_PG.intercept_) -def test_group_lasso_convergence(): +@pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) +def test_group_lasso_convergence(solver_name): """ Assert that solution found when using ProximalGradient versus Nelder-Mead method using group lasso GLM is the same. diff --git a/tests/test_glm.py b/tests/test_glm.py index e32de8f8..c9afef1c 100644 --- a/tests/test_glm.py +++ b/tests/test_glm.py @@ -7,10 +7,12 @@ import numpy as np import pytest import statsmodels.api as sm +from sklearn.linear_model import GammaRegressor, PoissonRegressor from sklearn.model_selection import GridSearchCV import nemos as nmo from nemos.pytrees import FeaturePytree +from nemos.tree_utils import pytree_map_and_reduce, tree_l2_norm, tree_slice, tree_sub def test_validate_higher_dimensional_data_X(mock_glm): @@ -855,27 +857,6 @@ def test_predict_n_feature_consistency_x( with expectation: model.predict(X) - @pytest.mark.parametrize( - "is_fit, expectation", - [ - (True, does_not_raise()), - ( - False, - pytest.raises(ValueError, match="This GLM instance is not fitted yet"), - ), - ], - ) - def test_predict_is_fit(self, is_fit, expectation, poissonGLM_model_instantiation): - """ - Test the `score` method on models based on their fit status. - Ensure scoring is only possible on fitted models. 
- """ - X, y, model, true_params, firing_rate = poissonGLM_model_instantiation - if is_fit: - model.fit(X, y) - with expectation: - model.predict(X) - ####################### # Test model.initialize_solver ####################### @@ -970,7 +951,9 @@ def test_initialize_solver_intercepts_dimensionality( self, dim_intercepts, expectation, poissonGLM_model_instantiation ): """ - Test the `initialize_solver` method with intercepts of different dimensionalities. Check for correct dimensionality. + Test the `initialize_solver` method with intercepts of different dimensionalities. + + Check for correct dimensionality. """ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation n_samples, n_features = X.shape @@ -1023,8 +1006,9 @@ def test_initialize_solver_init_params_type( self, init_params, expectation, poissonGLM_model_instantiation ): """ - Test the `initialize_solver` method with various types of initial parameters. Ensure that the provided initial parameters - are array-like. + Test the `initialize_solver` method with various types of initial parameters. + + Ensure that the provided initial parameters are array-like. """ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation with expectation: @@ -1043,7 +1027,9 @@ def test_initialize_solver_x_dimensionality( self, delta_dim, expectation, poissonGLM_model_instantiation ): """ - Test the `initialize_solver` method with X input data of different dimensionalities. Ensure correct dimensionality for X. + Test the `initialize_solver` method with X input data of different dimensionalities. + + Ensure correct dimensionality for X. """ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation if delta_dim == -1: @@ -1066,7 +1052,9 @@ def test_initialize_solver_y_dimensionality( self, delta_dim, expectation, poissonGLM_model_instantiation ): """ - Test the `initialize_solver` method with y target data of different dimensionalities. Ensure correct dimensionality for y. + Test the `initialize_solver` method with y target data of different dimensionalities. + + Ensure correct dimensionality for y. """ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation if delta_dim == -1: @@ -1143,7 +1131,9 @@ def test_initialize_solver_time_points_x( self, delta_tp, expectation, poissonGLM_model_instantiation ): """ - Test the `initialize_solver` method for inconsistencies in time-points in data X. Ensure the correct number of time-points. + Test the `initialize_solver` method for inconsistencies in time-points in data X. + + Ensure the correct number of time-points. """ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation X = jnp.zeros((X.shape[0] + delta_tp,) + X.shape[1:]) @@ -1169,7 +1159,9 @@ def test_initialize_solver_time_points_y( self, delta_tp, expectation, poissonGLM_model_instantiation ): """ - Test the `initialize_solver` method for inconsistencies in time-points in y. Ensure the correct number of time-points. + Test the `initialize_solver` method for inconsistencies in time-points in y. + + Ensure the correct number of time-points. 
""" X, y, model, true_params, firing_rate = poissonGLM_model_instantiation y = jnp.zeros((y.shape[0] + delta_tp,) + y.shape[1:]) @@ -1445,6 +1437,172 @@ def test_compatibility_with_sklearn_cv_gamma(self, gammaGLM_model_instantiation) param_grid = {"solver_name": ["BFGS", "GradientDescent"]} GridSearchCV(model, param_grid).fit(X, y) + @pytest.mark.parametrize( + "regr_setup, glm_class", + [ + ("poissonGLM_model_instantiation", nmo.glm.GLM), + ("poissonGLM_model_instantiation_pytree", nmo.glm.GLM), + ("poisson_population_GLM_model", nmo.glm.PopulationGLM), + ("poisson_population_GLM_model_pytree", nmo.glm.PopulationGLM), + ], + ) + @pytest.mark.parametrize( + "key", [jax.random.key(0), jax.random.key(19)] + ) + @pytest.mark.parametrize( + "regularizer_class, solver_name", + [ + (nmo.regularizer.UnRegularized, "SVRG"), + (nmo.regularizer.Ridge, "SVRG"), + (nmo.regularizer.Lasso, "ProxSVRG"), + # (nmo.regularizer.GroupLasso, "ProxSVRG"), + ] + ) + def test_glm_update_consistent_with_fit_with_svrg(self, request, regr_setup, glm_class, key, regularizer_class, + solver_name): + """ + Make sure that calling GLM.update with the rest of the algorithm implemented outside in a naive loop + is consistent with running the compiled GLM.fit on the same data with the same parameters + """ + jax.config.update("jax_enable_x64", True) + X, y, model, true_params, rate = request.getfixturevalue(regr_setup) + + N = y.shape[0] + batch_size = 1 + maxiter = 3 # number of epochs + tol = 1e-12 + stepsize = 1e-3 + + # has to match how the number of iterations is calculated in SVRG + m = int((N + batch_size - 1) // batch_size) + + regularizer_kwargs = {} + if regularizer_class.__name__ == "GroupLasso": + n_features = sum(x.shape[1] for x in jax.tree.leaves(X)) + regularizer_kwargs["mask"] = (np.random.randn(n_features) > 0).reshape(1, -1).astype(float) + + glm = glm_class( + regularizer=regularizer_class( + **regularizer_kwargs, + ), + solver_name=solver_name, + solver_kwargs={ + "batch_size": batch_size, + "stepsize": stepsize, + "tol": tol, + "maxiter": maxiter, + "key": key, + }, + ) + glm2 = glm_class( + regularizer=regularizer_class( + **regularizer_kwargs, + ), + solver_name=solver_name, + solver_kwargs={ + "batch_size": batch_size, + "stepsize": stepsize, + "tol": tol, + "maxiter": maxiter, + "key": key, + }, + ) + glm2.fit(X, y) + + params = glm.initialize_params(X, y) + state = glm.initialize_state(X, y, params) + glm.instantiate_solver() + + # NOTE these two are not the same because for example Ridge augments the loss + # loss_grad = jax.jit(jax.grad(glm._predict_and_compute_loss)) + loss_grad = jax.jit(jax.grad(glm._solver_loss_fun_)) + + # copied from GLM.fit + # grab data if needed (tree map won't function because param is never a FeaturePytree). 
+ if isinstance(X, FeaturePytree): + X = X.data + + iter_num = 0 + while iter_num < maxiter: + state = state._replace( + full_grad_at_reference_point=loss_grad(params, X, y), + ) + + prev_params = params + for _ in range(m): + key, subkey = jax.random.split(key) + ind = jax.random.randint(subkey, (batch_size,), 0, N) + xi, yi = tree_slice(X, ind), tree_slice(y, ind) + params, state = glm.update(params, state, xi, yi) + + state = state._replace( + reference_point=params, + ) + + iter_num += 1 + + _error = tree_l2_norm(tree_sub(params, prev_params)) / tree_l2_norm(prev_params) + if _error < tol: + break + + assert iter_num == glm2.solver_state_.iter_num + + assert pytree_map_and_reduce( + lambda a, b: np.allclose(a, b, atol=10**-5, rtol=0.0), + all, + (glm.coef_, glm.intercept_), + (glm2.coef_, glm2.intercept_), + ) + + @pytest.mark.parametrize("solver_name", ["GradientDescent", "SVRG"]) + def test_glm_fit_matches_sklearn_poisson(self, solver_name, poissonGLM_model_instantiation): + """Test that different solvers converge to the same solution.""" + jax.config.update("jax_enable_x64", True) + X, y, _, true_params, firing_rate = poissonGLM_model_instantiation + + model = nmo.glm.GLM( + regularizer=nmo.regularizer.UnRegularized(), + observation_model=nmo.observation_models.PoissonObservations(), + solver_name=solver_name, + solver_kwargs={"tol": 10**-12} + ) + # set precision to float64 for accurate matching of the results + model.data_type = jnp.float64 + model.fit(X, y) + + model_skl = PoissonRegressor(fit_intercept=True, tol=10**-12, alpha=0.0) + model_skl.fit(X, y) + + match_weights = jnp.allclose(model_skl.coef_, model.coef_, atol=1e-5, rtol=0.) + match_intercepts = jnp.allclose(model_skl.intercept_, model.intercept_, atol=1e-5, rtol=0.) + if (not match_weights) or (not match_intercepts): + raise ValueError("GLM.fit estimate does not match sklearn!") + + @pytest.mark.parametrize("solver_name", ["GradientDescent", "SVRG"]) + def test_glm_fit_matches_sklearn_gamma(self, solver_name, gammaGLM_model_instantiation): + """Test that different solvers converge to the same solution.""" + jax.config.update("jax_enable_x64", True) + X, y, _, true_params, firing_rate = gammaGLM_model_instantiation + + model = nmo.glm.GLM( + regularizer=nmo.regularizer.UnRegularized(), + observation_model=nmo.observation_models.GammaObservations(inverse_link_function=jnp.exp), + solver_name=solver_name, + solver_kwargs={"tol": 10**-12}, + ) + # set precision to float64 for accurate matching of the results + model.data_type = jnp.float64 + model.fit(X, y) + + model_skl = GammaRegressor(fit_intercept=True, tol=10**-12, alpha=0.0) + model_skl.fit(X, y) + + match_weights = jnp.allclose(model_skl.coef_, model.coef_, atol=1e-5, rtol=0.) + match_intercepts = jnp.allclose(model_skl.intercept_, model.intercept_, atol=1e-5, rtol=0.) + + if (not match_weights) or (not match_intercepts): + raise ValueError("GLM.fit estimate does not match sklearn!") + @pytest.mark.parametrize( "reg, dof", [ @@ -2154,7 +2312,9 @@ def test_initialize_solver_intercepts_dimensionality( self, dim_intercepts, expectation, poisson_population_GLM_model ): """ - Test the `initialize_solver` method with intercepts of different dimensionalities. Check for correct dimensionality. + Test the `initialize_solver` method with intercepts of different dimensionalities. + + Check for correct dimensionality. 
""" X, y, model, true_params, firing_rate = poisson_population_GLM_model n_samples, n_features = X.shape @@ -2198,8 +2358,9 @@ def test_initialize_solver_init_params_type( self, init_params, expectation, poisson_population_GLM_model ): """ - Test the `initialize_solver` method with various types of initial parameters. Ensure that the provided initial parameters - are array-like. + Test the `initialize_solver` method with various types of initial parameters. + + Ensure that the provided initial parameters are array-like. """ X, y, model, true_params, firing_rate = poisson_population_GLM_model with expectation: @@ -2218,7 +2379,9 @@ def test_initialize_solver_x_dimensionality( self, delta_dim, expectation, poisson_population_GLM_model ): """ - Test the `initialize_solver` method with X input data of different dimensionalities. Ensure correct dimensionality for X. + Test the `initialize_solver` method with X input data of different dimensionalities. + + Ensure correct dimensionality for X. """ X, y, model, true_params, firing_rate = poisson_population_GLM_model if delta_dim == -1: @@ -2241,7 +2404,9 @@ def test_initialize_solver_y_dimensionality( self, delta_dim, expectation, poisson_population_GLM_model ): """ - Test the `initialize_solver` method with y target data of different dimensionalities. Ensure correct dimensionality for y. + Test the `initialize_solver` method with y target data of different dimensionalities. + + Ensure correct dimensionality for y. """ X, y, model, true_params, firing_rate = poisson_population_GLM_model if delta_dim == -1: @@ -2318,7 +2483,9 @@ def test_initialize_solver_time_points_x( self, delta_tp, expectation, poisson_population_GLM_model ): """ - Test the `initialize_solver` method for inconsistencies in time-points in data X. Ensure the correct number of time-points. + Test the `initialize_solver` method for inconsistencies in time-points in data X. + + Ensure the correct number of time-points. """ X, y, model, true_params, firing_rate = poisson_population_GLM_model X = jnp.zeros((X.shape[0] + delta_tp,) + X.shape[1:]) @@ -2344,7 +2511,9 @@ def test_initialize_solver_time_points_y( self, delta_tp, expectation, poisson_population_GLM_model ): """ - Test the `initialize_solver` method for inconsistencies in time-points in y. Ensure the correct number of time-points. + Test the `initialize_solver` method for inconsistencies in time-points in y. + + Ensure the correct number of time-points. """ X, y, model, true_params, firing_rate = poisson_population_GLM_model y = jnp.zeros((y.shape[0] + delta_tp,) + y.shape[1:]) @@ -2528,7 +2697,7 @@ def test_score_n_feature_consistency_x( ), ], ) - def test_predict_is_fit(self, is_fit, expectation, poisson_population_GLM_model): + def test_predict_is_fit_population(self, is_fit, expectation, poisson_population_GLM_model): """ Test the `score` method on models based on their fit status. Ensure scoring is only possible on fitted models. 
diff --git a/tests/test_proximal_operator.py b/tests/test_proximal_operator.py index ad1309b6..59d162bc 100644 --- a/tests/test_proximal_operator.py +++ b/tests/test_proximal_operator.py @@ -1,71 +1,79 @@ -import jax import jax.numpy as jnp +import pytest -from nemos.proximal_operator import _vmap_norm2_masked_2, prox_group_lasso +from nemos.proximal_operator import _vmap_norm2_masked_2, prox_group_lasso, prox_lasso -def test_prox_group_lasso_returns_tuple(example_data_prox_operator): - """Test whether prox_group_lasso returns a tuple.""" - params, alpha, mask, scaling = example_data_prox_operator - updated_params = prox_group_lasso(params, alpha, mask, scaling) - assert isinstance(updated_params, tuple) +@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso]) +def test_prox_operator_returns_tuple(prox_operator, example_data_prox_operator): + """Test whether the proximal operator returns a tuple.""" + args = example_data_prox_operator + args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:]) + params_new = prox_operator(*args) + assert isinstance(params_new, tuple) -def test_prox_group_lasso_returns_tuple_multineuron( - example_data_prox_operator_multineuron, -): - """Test whether the tuple returned by prox_group_lasso has a length of 2.""" - params, alpha, mask, scaling = example_data_prox_operator_multineuron - updated_params = prox_group_lasso(params, alpha, mask, scaling) - assert isinstance(updated_params, tuple) +@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso]) +def test_prox_operator_returns_tuple_multineuron(prox_operator, example_data_prox_operator_multineuron): + """Test whether the tuple returned by the proximal operator has a length of 2.""" + args = example_data_prox_operator_multineuron + args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:]) + params_new = prox_operator(*args) + assert isinstance(params_new, tuple) -def test_prox_group_lasso_tuple_length(example_data_prox_operator): - """Test whether the tuple returned by prox_group_lasso has a length of 2.""" - params, alpha, mask, scaling = example_data_prox_operator - updated_params = prox_group_lasso(params, alpha, mask, scaling) - assert len(updated_params) == 2 +@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso]) +def test_prox_operator_tuple_length(prox_operator, example_data_prox_operator): + """Test whether the tuple returned by the proximal operator has a length of 2.""" + args = example_data_prox_operator + args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:]) + params_new = prox_operator(*args) + assert len(params_new) == 2 -def test_prox_group_lasso_tuple_length_multineuron( - example_data_prox_operator_multineuron, -): - """Test whether the tuple returned by prox_group_lasso has a length of 2.""" - params, alpha, mask, scaling = example_data_prox_operator_multineuron - updated_params = prox_group_lasso(params, alpha, mask, scaling) - assert len(updated_params) == 2 +@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso]) +def test_prox_operator_tuple_length_multineuron(prox_operator, example_data_prox_operator_multineuron): + """Test whether the tuple returned by the proximal operator has a length of 2.""" + args = example_data_prox_operator_multineuron + args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:]) + params_new = prox_operator(*args) + assert len(params_new) == 2 -def test_prox_group_lasso_weights_shape(example_data_prox_operator): 
- """Test whether the shape of the weights in prox_group_lasso is correct.""" - params, alpha, mask, scaling = example_data_prox_operator - updated_params = prox_group_lasso(params, alpha, mask, scaling) - assert updated_params[0].shape == params[0].shape +@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso]) +def test_prox_operator_weights_shape(prox_operator, example_data_prox_operator): + """Test whether the shape of the weights in the proximal operator is correct.""" + args = example_data_prox_operator + args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:]) + params_new = prox_operator(*args) + assert params_new[0].shape == args[0][0].shape -def test_prox_group_lasso_weights_shape_multineuron( - example_data_prox_operator_multineuron, -): - """Test whether the shape of the weights in prox_group_lasso is correct.""" - params, alpha, mask, scaling = example_data_prox_operator_multineuron - updated_params = prox_group_lasso(params, alpha, mask, scaling) - assert updated_params[0].shape == params[0].shape +@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso]) +def test_prox_operator_weights_shape_multineuron(prox_operator, example_data_prox_operator_multineuron): + """Test whether the shape of the weights in the proximal operator is correct.""" + args = example_data_prox_operator_multineuron + args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:]) + params_new = prox_operator(*args) + assert params_new[0].shape == args[0][0].shape -def test_prox_group_lasso_intercepts_shape(example_data_prox_operator): - """Test whether the shape of the intercepts in prox_group_lasso is correct.""" - params, alpha, mask, scaling = example_data_prox_operator - updated_params = prox_group_lasso(params, alpha, mask, scaling) - assert updated_params[1].shape == params[1].shape +@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso]) +def test_prox_operator_intercepts_shape(prox_operator, example_data_prox_operator): + """Test whether the shape of the intercepts in the proximal operator is correct.""" + args = example_data_prox_operator + args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:]) + params_new = prox_operator(*args) + assert params_new[1].shape == args[0][1].shape -def test_prox_group_lasso_intercepts_shape_multineuron( - example_data_prox_operator_multineuron, -): - """Test whether the shape of the intercepts in prox_group_lasso is correct.""" - params, alpha, mask, scaling = example_data_prox_operator_multineuron - updated_params = prox_group_lasso(params, alpha, mask, scaling) - assert updated_params[1].shape == params[1].shape +@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso]) +def test_prox_operator_intercepts_shape_multineuron(prox_operator, example_data_prox_operator_multineuron): + """Test whether the shape of the intercepts in the proximal operator is correct.""" + args = example_data_prox_operator_multineuron + args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:]) + params_new = prox_operator(*args) + assert params_new[1].shape == args[0][1].shape def test_vmap_norm2_masked_2_returns_array(example_data_prox_operator): @@ -96,9 +104,7 @@ def test_vmap_norm2_masked_2_non_negative(example_data_prox_operator): assert jnp.all(l2_norm >= 0) -def test_vmap_norm2_masked_2_non_negative_multineuron( - example_data_prox_operator_multineuron, -): +def 
test_vmap_norm2_masked_2_non_negative_multineuron(example_data_prox_operator_multineuron): """Test whether all elements of the result from _vmap_norm2_masked_2 are non-negative.""" params, _, mask, _ = example_data_prox_operator_multineuron l2_norm = _vmap_norm2_masked_2(params[0].T, mask) @@ -113,9 +119,7 @@ def test_prox_operator_shrinks_only_masked(example_data_prox_operator): assert all(params_new[0][i] < params[0][i] for i in [0, 2, 3]) -def test_prox_operator_shrinks_only_masked_multineuron( - example_data_prox_operator_multineuron, -): +def test_prox_operator_shrinks_only_masked_multineuron(example_data_prox_operator_multineuron): params, _, mask, _ = example_data_prox_operator_multineuron mask = mask.at[:, 1].set(jnp.zeros(2)) params_new = prox_group_lasso(params, 0.05, mask) diff --git a/tests/test_regularizer.py b/tests/test_regularizer.py index 4aa8b045..32565d07 100644 --- a/tests/test_regularizer.py +++ b/tests/test_regularizer.py @@ -1,7 +1,4 @@ import copy -import warnings -from contextlib import nullcontext as does_not_raise -from typing import NamedTuple import jax import jax.numpy as jnp @@ -142,7 +139,15 @@ class TestUnRegularized: @pytest.mark.parametrize( "solver_name", - ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1], + [ + "GradientDescent", + "BFGS", + "ProximalGradient", + "AGradientDescent", + 1, + "SVRG", + "ProxSVRG", + ], ) def test_init_solver_name(self, solver_name): """Test UnRegularized acceptable solvers.""" @@ -152,6 +157,8 @@ def test_init_solver_name(self, solver_name): "LBFGSB", "NonlinearCG", "ProximalGradient", + "SVRG", + "ProxSVRG", ] raise_exception = solver_name not in acceptable_solvers @@ -165,7 +172,15 @@ def test_init_solver_name(self, solver_name): @pytest.mark.parametrize( "solver_name", - ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1], + [ + "GradientDescent", + "BFGS", + "ProximalGradient", + "AGradientDescent", + 1, + "SVRG", + "ProxSVRG", + ], ) def test_set_solver_name_allowed(self, solver_name): """Test UnRegularized acceptable solvers.""" @@ -175,6 +190,8 @@ def test_set_solver_name_allowed(self, solver_name): "LBFGS", "NonlinearCG", "ProximalGradient", + "SVRG", + "ProxSVRG", ] regularizer = self.cls() model = nmo.glm.GLM(regularizer=regularizer) @@ -211,7 +228,9 @@ def test_get_params(self): assert regularizer.get_params() == {} - @pytest.mark.parametrize("solver_name", ["GradientDescent", "BFGS"]) + @pytest.mark.parametrize( + "solver_name", ["GradientDescent", "BFGS", "SVRG", "ProxSVRG"] + ) @pytest.mark.parametrize("solver_kwargs", [{"tol": 10**-10}, {"tols": 10**-10}]) def test_init_solver_kwargs(self, solver_name, solver_kwargs): """Test RidgeSolver acceptable kwargs.""" @@ -247,7 +266,8 @@ def test_loss_is_callable(self, loss): nmo.utils.assert_is_callable(model._predict_and_compute_loss, "loss") @pytest.mark.parametrize( - "solver_name", ["GradientDescent", "BFGS", "ProximalGradient"] + "solver_name", + ["GradientDescent", "BFGS", "ProximalGradient", "SVRG", "ProxSVRG"], ) def test_run_solver(self, solver_name, poissonGLM_model_instantiation): """Test that the solver runs.""" @@ -261,7 +281,8 @@ def test_run_solver(self, solver_name, poissonGLM_model_instantiation): model.solver_run((true_params[0] * 0.0, true_params[1]), X, y) @pytest.mark.parametrize( - "solver_name", ["GradientDescent", "BFGS", "ProximalGradient"] + "solver_name", + ["GradientDescent", "BFGS", "ProximalGradient", "SVRG", "ProxSVRG"], ) def test_run_solver_tree(self, solver_name, 
poissonGLM_model_instantiation_pytree): """Test that the solver runs.""" @@ -278,7 +299,8 @@ def test_run_solver_tree(self, solver_name, poissonGLM_model_instantiation_pytre y, ) - def test_solver_output_match(self, poissonGLM_model_instantiation): + @pytest.mark.parametrize("solver_name", ["GradientDescent", "SVRG"]) + def test_solver_output_match(self, poissonGLM_model_instantiation, solver_name): """Test that different solvers converge to the same solution.""" jax.config.update("jax_enable_x64", True) X, y, model, true_params, firing_rate = poissonGLM_model_instantiation @@ -286,7 +308,7 @@ def test_solver_output_match(self, poissonGLM_model_instantiation): model.data_type = jnp.float64 # set model params model.regularizer = self.cls() - model.solver_name = "GradientDescent" + model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-12} model.instantiate_solver() @@ -309,13 +331,15 @@ def test_solver_output_match(self, poissonGLM_model_instantiation): "Convex estimators should converge to the same numerical value." ) - def test_solver_match_sklearn(self, poissonGLM_model_instantiation): + @pytest.mark.parametrize("solver_name", ["GradientDescent", "SVRG"]) + def test_solver_match_sklearn(self, poissonGLM_model_instantiation, solver_name): """Test that different solvers converge to the same solution.""" jax.config.update("jax_enable_x64", True) X, y, model, true_params, firing_rate = poissonGLM_model_instantiation # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 model.regularizer = self.cls() + model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-12} model.instantiate_solver() weights_bfgs, intercepts_bfgs = model.solver_run( @@ -329,7 +353,10 @@ def test_solver_match_sklearn(self, poissonGLM_model_instantiation): if (not match_weights) or (not match_intercepts): raise ValueError("UnRegularized GLM estimate does not match sklearn!") - def test_solver_match_sklearn_gamma(self, gammaGLM_model_instantiation): + @pytest.mark.parametrize("solver_name", ["GradientDescent", "SVRG"]) + def test_solver_match_sklearn_gamma( + self, gammaGLM_model_instantiation, solver_name + ): """Test that different solvers converge to the same solution.""" jax.config.update("jax_enable_x64", True) X, y, model, true_params, firing_rate = gammaGLM_model_instantiation @@ -337,6 +364,7 @@ def test_solver_match_sklearn_gamma(self, gammaGLM_model_instantiation): model.data_type = jnp.float64 model.observation_model.inverse_link_function = jnp.exp model.regularizer = self.cls() + model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-12} model.instantiate_solver() weights_bfgs, intercepts_bfgs = model.solver_run( @@ -357,8 +385,10 @@ def test_solver_match_sklearn_gamma(self, gammaGLM_model_instantiation): (lambda x: 1 / x, sm.families.links.InversePower()), ], ) + # @pytest.mark.parametrize("solver_name", ["LBFGS", "GradientDescent", "SVRG"]) + @pytest.mark.parametrize("solver_name", ["LBFGS", "SVRG"]) def test_solver_match_statsmodels_gamma( - self, inv_link_jax, link_sm, gammaGLM_model_instantiation + self, inv_link_jax, link_sm, gammaGLM_model_instantiation, solver_name ): """Test that different solvers converge to the same solution.""" jax.config.update("jax_enable_x64", True) @@ -367,7 +397,7 @@ def test_solver_match_statsmodels_gamma( model.data_type = jnp.float64 model.observation_model.inverse_link_function = inv_link_jax model.regularizer = self.cls() - model.solver_name = "LBFGS" + model.solver_name = solver_name 
model.solver_kwargs = {"tol": 10**-13} model.instantiate_solver() weights_bfgs, intercepts_bfgs = model.solver_run( @@ -393,6 +423,8 @@ def test_solver_match_statsmodels_gamma( "LBFGS", "NonlinearCG", "ProximalGradient", + "SVRG", + "ProxSVRG", ], ) def test_solver_combination(self, solver_name, poissonGLM_model_instantiation): @@ -407,7 +439,15 @@ class TestRidge: @pytest.mark.parametrize( "solver_name", - ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1], + [ + "GradientDescent", + "BFGS", + "ProximalGradient", + "AGradientDescent", + 1, + "SVRG", + "ProxSVRG", + ], ) def test_init_solver_name(self, solver_name): """Test RidgeSolver acceptable solvers.""" @@ -417,6 +457,8 @@ def test_init_solver_name(self, solver_name): "LBFGS", "NonlinearCG", "ProximalGradient", + "SVRG", + "ProxSVRG", ] raise_exception = solver_name not in acceptable_solvers if raise_exception: @@ -429,7 +471,15 @@ def test_init_solver_name(self, solver_name): @pytest.mark.parametrize( "solver_name", - ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1], + [ + "GradientDescent", + "BFGS", + "ProximalGradient", + "AGradientDescent", + 1, + "SVRG", + "ProxSVRG", + ], ) def test_set_solver_name_allowed(self, solver_name): """Test RidgeSolver acceptable solvers.""" @@ -440,6 +490,8 @@ def test_set_solver_name_allowed(self, solver_name): "LBFGSB", "NonlinearCG", "ProximalGradient", + "SVRG", + "ProxSVRG", ] regularizer = self.cls() model = nmo.glm.GLM(regularizer=regularizer) @@ -452,7 +504,7 @@ def test_set_solver_name_allowed(self, solver_name): else: model.set_params(solver_name=solver_name) - @pytest.mark.parametrize("solver_name", ["GradientDescent", "BFGS"]) + @pytest.mark.parametrize("solver_name", ["GradientDescent", "BFGS", "SVRG"]) @pytest.mark.parametrize("solver_kwargs", [{"tol": 10**-10}, {"tols": 10**-10}]) def test_init_solver_kwargs(self, solver_name, solver_kwargs): """Test Ridge acceptable kwargs.""" @@ -513,7 +565,8 @@ def test_loss_is_callable(self, loss): nmo.utils.assert_is_callable(model._predict_and_compute_loss, "loss") @pytest.mark.parametrize( - "solver_name", ["GradientDescent", "BFGS", "ProximalGradient"] + "solver_name", + ["GradientDescent", "BFGS", "ProximalGradient", "SVRG", "ProxSVRG"], ) def test_run_solver(self, solver_name, poissonGLM_model_instantiation): """Test that the solver runs.""" @@ -527,7 +580,8 @@ def test_run_solver(self, solver_name, poissonGLM_model_instantiation): runner((true_params[0] * 0.0, true_params[1]), X, y) @pytest.mark.parametrize( - "solver_name", ["GradientDescent", "BFGS", "ProximalGradient"] + "solver_name", + ["GradientDescent", "BFGS", "ProximalGradient", "SVRG", "ProxSVRG"], ) def test_run_solver_tree(self, solver_name, poissonGLM_model_instantiation_pytree): """Test that the solver runs.""" @@ -544,7 +598,8 @@ def test_run_solver_tree(self, solver_name, poissonGLM_model_instantiation_pytre y, ) - def test_solver_output_match(self, poissonGLM_model_instantiation): + @pytest.mark.parametrize("solver_name", ["GradientDescent", "SVRG"]) + def test_solver_output_match(self, poissonGLM_model_instantiation, solver_name): """Test that different solvers converge to the same solution.""" jax.config.update("jax_enable_x64", True) X, y, model, true_params, firing_rate = poissonGLM_model_instantiation @@ -553,7 +608,7 @@ def test_solver_output_match(self, poissonGLM_model_instantiation): # set model params model.regularizer = self.cls() - model.solver_name = "GradientDescent" + model.solver_name = solver_name model.solver_kwargs 
= {"tol": 10**-12} model_bfgs = copy.deepcopy(model) @@ -585,6 +640,7 @@ def test_solver_match_sklearn(self, poissonGLM_model_instantiation): model.data_type = jnp.float64 model.regularizer = self.cls() model.solver_kwargs = {"tol": 10**-12} + model.solver_name = "BFGS" runner_bfgs = model.instantiate_solver().solver_run weights_bfgs, intercepts_bfgs = runner_bfgs( @@ -612,6 +668,7 @@ def test_solver_match_sklearn_gamma(self, gammaGLM_model_instantiation): model.regularizer = self.cls() model.solver_kwargs = {"tol": 10**-12} model.regularizer_strength = 0.1 + model.solver_name = "BFGS" runner_bfgs = model.instantiate_solver().solver_run weights_bfgs, intercepts_bfgs = runner_bfgs( (true_params[0] * 0.0, true_params[1]), X, y @@ -650,11 +707,22 @@ class TestLasso: @pytest.mark.parametrize( "solver_name", - ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1], + [ + "GradientDescent", + "BFGS", + "ProximalGradient", + "AGradientDescent", + 1, + "SVRG", + "ProxSVRG", + ], ) def test_init_solver_name(self, solver_name): """Test Lasso acceptable solvers.""" - acceptable_solvers = ["ProximalGradient"] + acceptable_solvers = [ + "ProximalGradient", + "ProxSVRG", + ] raise_exception = solver_name not in acceptable_solvers if raise_exception: with pytest.raises( @@ -666,11 +734,22 @@ def test_init_solver_name(self, solver_name): @pytest.mark.parametrize( "solver_name", - ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1], + [ + "GradientDescent", + "BFGS", + "ProximalGradient", + "AGradientDescent", + 1, + "SVRG", + "ProxSVRG", + ], ) def test_set_solver_name_allowed(self, solver_name): """Test Lasso acceptable solvers.""" - acceptable_solvers = ["ProximalGradient"] + acceptable_solvers = [ + "ProximalGradient", + "ProxSVRG", + ] regularizer = self.cls() model = nmo.glm.GLM(regularizer=regularizer) raise_exception = solver_name not in acceptable_solvers @@ -682,8 +761,9 @@ def test_set_solver_name_allowed(self, solver_name): else: model.set_params(solver_name=solver_name) + @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) @pytest.mark.parametrize("solver_kwargs", [{"tol": 10**-10}, {"tols": 10**-10}]) - def test_init_solver_kwargs(self, solver_kwargs): + def test_init_solver_kwargs(self, solver_kwargs, solver_name): """Test LassoSolver acceptable kwargs.""" regularizer = self.cls() raise_exception = "tols" in list(solver_kwargs.keys()) @@ -691,9 +771,17 @@ def test_init_solver_kwargs(self, solver_kwargs): with pytest.raises( NameError, match="kwargs {'tols'} in solver_kwargs not a kwarg" ): - nmo.glm.GLM(regularizer=regularizer, solver_kwargs=solver_kwargs) + nmo.glm.GLM( + regularizer=regularizer, + solver_name=solver_name, + solver_kwargs=solver_kwargs, + ) else: - nmo.glm.GLM(regularizer=regularizer, solver_kwargs=solver_kwargs) + nmo.glm.GLM( + regularizer=regularizer, + solver_name=solver_name, + solver_kwargs=solver_kwargs, + ) def test_regularizer_strength_none(self): """Assert regularizer strength handled appropriately.""" @@ -733,17 +821,18 @@ def test_loss_callable(self, loss): else: nmo.utils.assert_is_callable(model._predict_and_compute_loss, "loss") - def test_run_solver(self, poissonGLM_model_instantiation): + @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) + def test_run_solver(self, solver_name, poissonGLM_model_instantiation): """Test that the solver runs.""" X, y, model, true_params, firing_rate = poissonGLM_model_instantiation model.regularizer = self.cls() - model.solver_name = 
"ProximalGradient" + model.solver_name = solver_name runner = model.instantiate_solver().solver_run runner((true_params[0] * 0.0, true_params[1]), X, y) - @pytest.mark.parametrize("solver_name", ["ProximalGradient"]) + @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) def test_run_solver_tree(self, solver_name, poissonGLM_model_instantiation_pytree): """Test that the solver runs.""" @@ -759,14 +848,17 @@ def test_run_solver_tree(self, solver_name, poissonGLM_model_instantiation_pytre y, ) - def test_solver_match_statsmodels(self, poissonGLM_model_instantiation): + @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) + def test_solver_match_statsmodels( + self, solver_name, poissonGLM_model_instantiation + ): """Test that different solvers converge to the same solution.""" jax.config.update("jax_enable_x64", True) X, y, model, true_params, firing_rate = poissonGLM_model_instantiation # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 model.regularizer = self.cls() - model.solver_name = "ProximalGradient" + model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-12} runner = model.instantiate_solver().solver_run @@ -797,10 +889,12 @@ def test_lasso_pytree(self, poissonGLM_model_instantiation_pytree): model.solver_name = "ProximalGradient" model.fit(X, y) + @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) @pytest.mark.parametrize("reg_str", [0.001, 0.01, 0.1, 1, 10]) def test_lasso_pytree_match( self, reg_str, + solver_name, poissonGLM_model_instantiation_pytree, poissonGLM_model_instantiation, ): @@ -813,20 +907,15 @@ def test_lasso_pytree_match( model_array.regularizer_strength = reg_str model.regularizer = nmo.regularizer.Lasso() model_array.regularizer = nmo.regularizer.Lasso() - model.solver_name = "ProximalGradient" - model_array.solver_name = "ProximalGradient" + model.solver_name = solver_name + model_array.solver_name = solver_name model.fit(X, y) model_array.fit(X_array, y) assert np.allclose( np.hstack(jax.tree_util.tree_leaves(model.coef_)), model_array.coef_ ) - @pytest.mark.parametrize( - "solver_name", - [ - "ProximalGradient", - ], - ) + @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) def test_solver_combination(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation model.regularizer = self.cls() @@ -839,11 +928,22 @@ class TestGroupLasso: @pytest.mark.parametrize( "solver_name", - ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1], + [ + "GradientDescent", + "BFGS", + "ProximalGradient", + "AGradientDescent", + 1, + "SVRG", + "ProxSVRG", + ], ) def test_init_solver_name(self, solver_name): """Test GroupLasso acceptable solvers.""" - acceptable_solvers = ["ProximalGradient"] + acceptable_solvers = [ + "ProximalGradient", + "ProxSVRG", + ] raise_exception = solver_name not in acceptable_solvers # create a valid mask @@ -862,11 +962,22 @@ def test_init_solver_name(self, solver_name): @pytest.mark.parametrize( "solver_name", - ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1], + [ + "GradientDescent", + "BFGS", + "ProximalGradient", + "AGradientDescent", + 1, + "SVRG", + "ProxSVRG", + ], ) def test_set_solver_name_allowed(self, solver_name): """Test GroupLassoSolver acceptable solvers.""" - acceptable_solvers = ["ProximalGradient"] + acceptable_solvers = [ + "ProximalGradient", + "ProxSVRG", + ] # create a valid mask mask = 
np.zeros((2, 10)) mask[0, :5] = 1 @@ -883,8 +994,9 @@ def test_set_solver_name_allowed(self, solver_name): else: model.set_params(solver_name=solver_name) + @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) @pytest.mark.parametrize("solver_kwargs", [{"tol": 10**-10}, {"tols": 10**-10}]) - def test_init_solver_kwargs(self, solver_kwargs): + def test_init_solver_kwargs(self, solver_name, solver_kwargs): """Test GroupLasso acceptable kwargs.""" raise_exception = "tols" in list(solver_kwargs.keys()) @@ -900,9 +1012,17 @@ def test_init_solver_kwargs(self, solver_kwargs): with pytest.raises( NameError, match="kwargs {'tols'} in solver_kwargs not a kwarg" ): - nmo.glm.GLM(regularizer=regularizer, solver_kwargs=solver_kwargs) + nmo.glm.GLM( + regularizer=regularizer, + solver_name=solver_name, + solver_kwargs=solver_kwargs, + ) else: - nmo.glm.GLM(regularizer=regularizer, solver_kwargs=solver_kwargs) + nmo.glm.GLM( + regularizer=regularizer, + solver_name=solver_name, + solver_kwargs=solver_kwargs, + ) def test_regularizer_strength_none(self): """Assert regularizer strength handled appropriately.""" @@ -950,7 +1070,8 @@ def test_loss_callable(self, loss): else: nmo.utils.assert_is_callable(model._predict_and_compute_loss, "loss") - def test_run_solver(self, poissonGLM_model_instantiation): + @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) + def test_run_solver(self, solver_name, poissonGLM_model_instantiation): """Test that the solver runs.""" X, y, model, true_params, firing_rate = poissonGLM_model_instantiation @@ -962,12 +1083,13 @@ def test_run_solver(self, poissonGLM_model_instantiation): mask = jnp.asarray(mask) model.regularizer = self.cls(mask=mask) - model.solver_name = "ProximalGradient" + model.solver_name = solver_name model.instantiate_solver() model.solver_run((true_params[0] * 0.0, true_params[1]), X, y) - def test_init_solver(self, poissonGLM_model_instantiation): + @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) + def test_init_solver(self, solver_name, poissonGLM_model_instantiation): """Test that the solver initialization returns a state.""" X, y, model, true_params, firing_rate = poissonGLM_model_instantiation @@ -979,7 +1101,7 @@ def test_init_solver(self, poissonGLM_model_instantiation): mask = jnp.asarray(mask) model.regularizer = self.cls(mask=mask) - model.solver_name = "ProximalGradient" + model.solver_name = solver_name model.instantiate_solver() state = model.solver_init_state(true_params, X, y) @@ -992,7 +1114,8 @@ def test_init_solver(self, poissonGLM_model_instantiation): and hasattr(state, "_asdict") ) - def test_update_solver(self, poissonGLM_model_instantiation): + @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) + def test_update_solver(self, solver_name, poissonGLM_model_instantiation): """Test that the solver initialization returns a state.""" X, y, model, true_params, firing_rate = poissonGLM_model_instantiation @@ -1004,11 +1127,17 @@ def test_update_solver(self, poissonGLM_model_instantiation): mask = jnp.asarray(mask) model.regularizer = self.cls(mask=mask) - model.solver_name = "ProximalGradient" + model.solver_name = solver_name model.instantiate_solver() state = model.solver_init_state((true_params[0] * 0.0, true_params[1]), X, y) + + # ProxSVRG needs the full gradient at the anchor point to be initialized + # so here just set it to xs, which is not correct, but fine shape-wise + if solver_name == "ProxSVRG": + state = 
state._replace(full_grad_at_reference_point=state.reference_point) + params, state = model.solver_update(true_params, state, X, y) # asses that state is a NamedTuple by checking tuple type and the availability of some NamedTuple # specific namespace attributes @@ -1019,7 +1148,7 @@ def test_update_solver(self, poissonGLM_model_instantiation): and hasattr(state, "_asdict") ) # check params struct and shapes - assert jax.tree_util.tree_structure(params) == jax.tree_structure(true_params) + assert jax.tree_util.tree_structure(params) == jax.tree_util.tree_structure(true_params) assert all( jax.tree_util.tree_leaves(params)[k].shape == p.shape for k, p in enumerate(jax.tree_util.tree_leaves(true_params)) @@ -1290,12 +1419,7 @@ def test_mask_none(self, poissonGLM_model_instantiation): model.solver_name = "ProximalGradient" model.fit(X, y) - @pytest.mark.parametrize( - "solver_name", - [ - "ProximalGradient", - ], - ) + @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) def test_solver_combination(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation model.regularizer = self.cls() diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 4faaa004..70ff315a 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -47,44 +47,6 @@ def test_difference_of_gammas_excit_a(excit_a, expectation): simulation.difference_of_gammas(10, excit_a=excit_a) -@pytest.mark.parametrize( - "inhib_b, expectation", - [ - ( - -1, - pytest.raises(ValueError, match="Gamma parameter [a-z]+_[a,b] must be >0."), - ), - ( - 0, - pytest.raises(ValueError, match="Gamma parameter [a-z]+_[a,b] must be >0."), - ), - (1, does_not_raise()), - ], -) -def test_difference_of_gammas_excit_a(inhib_b, expectation): - with expectation: - simulation.difference_of_gammas(10, inhib_b=inhib_b) - - -@pytest.mark.parametrize( - "excit_b, expectation", - [ - ( - -1, - pytest.raises(ValueError, match="Gamma parameter [a-z]+_[a,b] must be >0."), - ), - ( - 0, - pytest.raises(ValueError, match="Gamma parameter [a-z]+_[a,b] must be >0."), - ), - (1, does_not_raise()), - ], -) -def test_difference_of_gammas_excit_a(excit_b, expectation): - with expectation: - simulation.difference_of_gammas(10, excit_b=excit_b) - - @pytest.mark.parametrize( "upper_percentile, expectation", [ @@ -230,16 +192,16 @@ def test_regress_filter_weights_size( np.zeros((window_size, n_basis_funcs)), ) assert weights.shape[0] == n_neurons_sender, ( - f"First dimension of weights (n_neurons_receiver) does not " - f"match the second dimension of coupling_filters." + "First dimension of weights (n_neurons_receiver) does not " + "match the second dimension of coupling_filters." ) assert weights.shape[1] == n_neurons_receiver, ( - f"Second dimension of weights (n_neuron_sender) does not " - f"match the third dimension of coupling_filters." + "Second dimension of weights (n_neuron_sender) does not " + "match the third dimension of coupling_filters." ) assert weights.shape[2] == n_basis_funcs, ( - f"Third dimension of weights (n_basis_funcs) does not " - f"match the second dimension of eval_basis." + "Third dimension of weights (n_basis_funcs) does not " + "match the second dimension of eval_basis." 
) diff --git a/tests/test_solvers.py b/tests/test_solvers.py new file mode 100644 index 00000000..72970397 --- /dev/null +++ b/tests/test_solvers.py @@ -0,0 +1,573 @@ +import inspect +from contextlib import nullcontext as does_not_raise + +import jax +import jaxopt +import numpy as np +import pytest + +import nemos as nmo +from nemos.solvers import SVRG, ProxSVRG, SVRGState +from nemos.tree_utils import pytree_map_and_reduce, tree_l2_norm, tree_slice, tree_sub + + +@pytest.mark.parametrize( + ("regr_setup", "stepsize"), + [ + ("linear_regression", 1e-3), + ("ridge_regression", 1e-4), + ("linear_regression_tree", 1e-4), + ("ridge_regression_tree", 1e-4), + ], +) +def test_svrg_linear_or_ridge_regression(request, regr_setup, stepsize): + jax.config.update("jax_enable_x64", True) + X, y, _, params, loss = request.getfixturevalue(regr_setup) + + param_init = jax.tree_util.tree_map(np.zeros_like, params) + svrg_params, state = SVRG(loss, tol=10**-12, stepsize=stepsize).run( + param_init, X, y + ) + assert pytree_map_and_reduce( + lambda a, b: np.allclose(a, b, atol=10**-5, rtol=0.0), all, params, svrg_params + ) + + +@pytest.mark.parametrize( + "regr_setup", + [ + "linear_regression", + "ridge_regression", + "linear_regression_tree", + "ridge_regression_tree", + ], +) +def test_svrg_init_state_default(request, regr_setup): + jax.config.update("jax_enable_x64", True) + X, y, _, params, loss = request.getfixturevalue(regr_setup) + + param_init = jax.tree_util.tree_map(np.zeros_like, params) + svrg = SVRG(loss) + state = svrg.init_state(param_init, X, y) + + assert state.iter_num == 0 + assert state.key == jax.random.key(123) + assert state.full_grad_at_reference_point is None + assert state.reference_point is not None + + +@pytest.mark.parametrize( + "regr_setup", + [ + "linear_regression", + "ridge_regression", + "linear_regression_tree", + "ridge_regression_tree", + ], +) +def test_svrg_init_state_key(request, regr_setup): + random_key = jax.random.key(1000) + + jax.config.update("jax_enable_x64", True) + X, y, _, params, loss = request.getfixturevalue(regr_setup) + + param_init = jax.tree_util.tree_map(np.zeros_like, params) + svrg = SVRG(loss, key=random_key) + state = svrg.init_state(param_init, X, y) + + assert state.key == random_key + + +@pytest.mark.parametrize( + "regr_setup", + [ + "linear_regression", + "linear_regression_tree", + ], +) +@pytest.mark.parametrize( + "solver_class, prox, prox_lambda", + [(SVRG, None, None), (ProxSVRG, jaxopt.prox.prox_ridge, 0.1)], +) +def test_svrg_update_needs_df_xs(request, regr_setup, solver_class, prox, prox_lambda): + jax.config.update("jax_enable_x64", True) + X, y, _, params, loss = request.getfixturevalue(regr_setup) + + param_init = jax.tree_util.tree_map(np.zeros_like, params) + if prox_lambda is not None: + args = (prox_lambda, X, y) + constr_args = (loss, prox) + else: + args = (X, y) + constr_args = (loss,) + + solver_class = solver_class(*constr_args) + state = solver_class.init_state(param_init, *args) + + with pytest.raises( + ValueError, + match=r"Full gradient at the anchor point \(state\.full_grad_at_reference_point\) has to be set", + ): + _, _ = solver_class.update(param_init, state, *args) + + +@pytest.mark.parametrize( + "regularizer_name, solver_class, mask", + [ + ("Lasso", ProxSVRG, None), + ("GroupLasso", ProxSVRG, np.array([0, 1, 0, 1]).reshape(1, -1).astype(float)), + ("Ridge", SVRG, None), + ("UnRegularized", SVRG, None), + ], +) +def test_svrg_glm_instantiate_solver(regularizer_name, solver_class, mask): + solver_name = 
solver_class.__name__ + + # only pass mask if it's not None + kwargs = {"solver_name": solver_name} + if mask is not None: + kwargs["mask"] = mask + + glm = nmo.glm.GLM(regularizer=regularizer_name, solver_name=solver_name) + glm.instantiate_solver() + + solver = inspect.getclosurevars(glm._solver_run).nonlocals["solver"] + assert glm.solver_name == solver_name + assert isinstance(solver, solver_class) + + +@pytest.mark.parametrize( + "regularizer_name, solver_name, mask", + [ + ("Lasso", "ProxSVRG", None), + ("GroupLasso", "ProxSVRG", np.array([0, 1, 0, 1]).reshape(1, -1).astype(float)), + ("Ridge", "SVRG", None), + ("UnRegularized", "SVRG", None), + ], +) +@pytest.mark.parametrize("glm_class", [nmo.glm.GLM, nmo.glm.PopulationGLM]) +def test_svrg_glm_passes_solver_kwargs(regularizer_name, solver_name, mask, glm_class): + solver_kwargs = { + "stepsize": np.abs(np.random.randn()), + "maxiter": np.random.randint(1, 100), + } + + # only pass mask if it's not None + kwargs = {} + if mask is not None and glm_class == nmo.glm.PopulationGLM: + kwargs["feature_mask"] = mask + + glm = glm_class( + regularizer=regularizer_name, + solver_name=solver_name, + solver_kwargs=solver_kwargs, + **kwargs, + ) + glm.instantiate_solver() + + solver = inspect.getclosurevars(glm._solver_run).nonlocals["solver"] + assert solver.stepsize == solver_kwargs["stepsize"] + assert solver.maxiter == solver_kwargs["maxiter"] + + +@pytest.mark.parametrize( + "regularizer_name, solver_class, mask", + [ + ("Lasso", ProxSVRG, None), + ( + "GroupLasso", + ProxSVRG, + np.array([[0], [0], [1]]), + ), + ("GroupLasso", ProxSVRG, None), + ("Ridge", SVRG, None), + ("UnRegularized", SVRG, None), + ], +) +@pytest.mark.parametrize( + "glm_class", + [nmo.glm.GLM, nmo.glm.PopulationGLM], +) +def test_svrg_glm_initialize_state( + glm_class, regularizer_name, solver_class, mask, linear_regression +): + X, y, _, _, _ = linear_regression + + if glm_class == nmo.glm.PopulationGLM: + y = np.expand_dims(y, 1) + + # only pass mask if it's not None + kwargs = {} + if mask is not None and glm_class == nmo.glm.PopulationGLM: + kwargs["feature_mask"] = mask + + glm = glm_class( + regularizer=regularizer_name, + solver_name=solver_class.__name__, + observation_model=nmo.observation_models.PoissonObservations(jax.nn.softplus), + **kwargs, + ) + + init_params = glm.initialize_params(X, y) + state = glm.initialize_state(X, y, init_params) + + assert state.reference_point == init_params + + for f in (glm._solver_init_state, glm._solver_update, glm._solver_run): + assert isinstance(inspect.getclosurevars(f).nonlocals["solver"], solver_class) + assert isinstance(state, SVRGState) + + +@pytest.mark.parametrize( + "regularizer_name, solver_class, mask", + [ + ("Lasso", ProxSVRG, None), + ( + "GroupLasso", + ProxSVRG, + np.array([[0], [0], [1]]), + ), + ("Ridge", SVRG, None), + ("UnRegularized", SVRG, None), + ], +) +@pytest.mark.parametrize( + "glm_class", + [nmo.glm.GLM, nmo.glm.PopulationGLM], +) +def test_svrg_glm_update( + glm_class, regularizer_name, solver_class, mask, linear_regression +): + X, y, _, _, loss = linear_regression + if glm_class == nmo.glm.PopulationGLM: + y = np.expand_dims(y, 1) + + # only pass mask if it's not None + kwargs = {} + if mask is not None and glm_class == nmo.glm.PopulationGLM: + kwargs["feature_mask"] = mask + + glm = glm_class( + regularizer=regularizer_name, + solver_name=solver_class.__name__, + observation_model=nmo.observation_models.PoissonObservations(jax.nn.softplus), + **kwargs, + ) + + init_params = 
glm.initialize_params(X, y) + state = glm.initialize_state(X, y, init_params) + + loss_gradient = jax.jit(jax.grad(glm._solver_loss_fun_)) + + # initialize full gradient at the anchor point + state = state._replace( + full_grad_at_reference_point=loss_gradient(init_params, X, y), + ) + + params, state = glm.update(init_params, state, X, y) + + assert state.iter_num == 1 + + +@pytest.mark.parametrize( + "regularizer_name, solver_name, mask", + [ + ("Lasso", "ProxSVRG", None), + ( + "GroupLasso", + "ProxSVRG", + np.array([[0, 1, 0], [0, 0, 1]]).reshape(2, -1).astype(float), + ), + ("GroupLasso", "ProxSVRG", None), + ( + "GroupLasso", + "ProximalGradient", + np.array([[0, 1, 0], [0, 0, 1]]).reshape(2, -1).astype(float), + ), + ("GroupLasso", "ProximalGradient", None), + ("Ridge", "SVRG", None), + ("UnRegularized", "SVRG", None), + ], +) +@pytest.mark.parametrize( + "maxiter", + [3, 50], +) +@pytest.mark.parametrize( + "glm_class", + [nmo.glm.GLM, nmo.glm.PopulationGLM], +) +def test_svrg_glm_fit( + glm_class, + regularizer_name, + solver_name, + mask, + poissonGLM_model_instantiation, + maxiter, +): + X, y, model, (w_true, b_true), rate = poissonGLM_model_instantiation + + # set tolerance to -1 so that doesn't stop the iteration + solver_kwargs = { + "maxiter": maxiter, + "tol": -1.0, + } + + # only pass mask if it's not None + kwargs = {} + if mask is not None: + kwargs["feature_mask"] = mask + + glm = glm_class( + regularizer=regularizer_name, + solver_name=solver_name, + observation_model=nmo.observation_models.PoissonObservations(jax.nn.softplus), + solver_kwargs=solver_kwargs, + ) + + if isinstance(glm, nmo.glm.PopulationGLM): + y = np.expand_dims(y, 1) + + glm.fit(X, y) + + solver = inspect.getclosurevars(glm._solver_run).nonlocals["solver"] + assert solver.maxiter == maxiter + assert glm.solver_state_.iter_num == maxiter + + +@pytest.mark.parametrize( + "regularizer_name, solver_class, mask", + [ + ("Lasso", ProxSVRG, None), + ("GroupLasso", ProxSVRG, np.array([0, 1, 0]).reshape(1, -1).astype(float)), + ("Ridge", SVRG, None), + ("UnRegularized", SVRG, None), + ], +) +@pytest.mark.parametrize( + "glm_class", + [nmo.glm.GLM, nmo.glm.PopulationGLM], +) +def test_svrg_glm_update_needs_full_grad_at_reference_point( + glm_class, regularizer_name, solver_class, mask, linear_regression +): + X, y, _, _, loss = linear_regression + if glm_class.__name__ == "PopulationGLM": + y = np.expand_dims(y, 1) + + # only pass mask if it's not None + kwargs = {} + if mask is not None and glm_class == nmo.glm.PopulationGLM: + kwargs["feature_mask"] = mask + + glm = glm_class( + regularizer=regularizer_name, + solver_name=solver_class.__name__, + observation_model=nmo.observation_models.PoissonObservations(jax.nn.softplus), + ) + + with pytest.raises( + ValueError, + match=r"Full gradient at the anchor point \(state\.full_grad_at_reference_point\) has to be set", + ): + params = glm.initialize_params(X, y) + state = glm.initialize_state(X, y, params) + glm.update(params, state, X, y) + + +@pytest.mark.parametrize( + ("regr_setup", "stepsize"), + [ + ("linear_regression", 1e-3), + ("ridge_regression", 1e-4), + ("linear_regression_tree", 1e-4), + ("ridge_regression_tree", 1e-4), + ], +) +def test_svrg_update_converges(request, regr_setup, stepsize): + jax.config.update("jax_enable_x64", True) + X, y, _, analytical_params, loss = request.getfixturevalue(regr_setup) + + loss_grad = jax.jit(jax.grad(loss)) + + N = y.shape[0] + batch_size = 1 + maxiter = 10_000 + tol = 1e-12 + key = jax.random.key(123) + + m = 
int((N + batch_size - 1) // batch_size) + + solver = SVRG(loss, stepsize=stepsize, batch_size=batch_size) + params = jax.tree_util.tree_map(np.zeros_like, analytical_params) + state = solver.init_state(params, X, y) + + for _ in range(maxiter): + state = state._replace( + full_grad_at_reference_point=loss_grad(params, X, y), + ) + + prev_params = params + for _ in range(m): + key, subkey = jax.random.split(key) + ind = jax.random.randint(subkey, (batch_size,), 0, N) + xi, yi = tree_slice(X, ind), y[ind] + params, state = solver.update(params, state, xi, yi) + + state = state._replace( + reference_point=params, + ) + + _error = tree_l2_norm(tree_sub(params, prev_params)) / tree_l2_norm(prev_params) + if _error < tol: + break + + assert pytree_map_and_reduce( + lambda a, b: np.allclose(a, b, atol=10**-5, rtol=0.0), + all, + analytical_params, + params, + ) + + +@pytest.mark.parametrize( + "regr_setup, to_tuple", + [ + ("linear_regression", True), + ("linear_regression", False), + ("linear_regression_tree", False), + ], +) +@pytest.mark.parametrize( + "prox, prox_lambda", + [ + (jaxopt.prox.prox_none, None), + (jaxopt.prox.prox_ridge, 0.1), + (jaxopt.prox.prox_none, 0.1), + (nmo.proximal_operator.prox_lasso, 0.1), + ], +) +def test_svrg_xk_update_step(request, regr_setup, to_tuple, prox, prox_lambda): + + X, y, true_params, ols_coef, loss_arr = request.getfixturevalue(regr_setup) + + # the loss takes an array, but I want to test with tuples as well + # so make a new loss function that takes a tuple + if to_tuple: + true_params = ( + true_params, + np.zeros(X.shape[1]), + ) + loss = lambda params, X, y: loss_arr(params[0], X, y) + else: + loss = loss_arr + + stepsize = 1e-2 + loss_gradient = jax.jit(jax.grad(loss)) + + # set the initial parameters to zero and + # set the anchor point to a random value that's not just zeros + init_param = jax.tree_util.tree_map(np.zeros_like, true_params) + xs = jax.tree_util.tree_map(lambda x: np.random.randn(*x.shape), true_params) + df_xs = loss_gradient(xs, X, y) + + # sample a mini-batch + key = jax.random.key(123) + key, subkey = jax.random.split(key) + ind = jax.random.randint(subkey, (32,), 0, y.shape[0]) + xi, yi = tree_slice(X, ind), tree_slice(y, ind) + + dfik_xk = loss_gradient(init_param, xi, yi) + dfik_xs = loss_gradient(xs, xi, yi) + + # update if inputs are arrays + def _array_update(dfik_xk, dfik_xs, df_xs, init_param, stepsize): + gk = dfik_xk - dfik_xs + df_xs + next_xk = init_param - stepsize * gk + return next_xk + + # update if inputs are a tuple of arrays + def _tuple_update(dfik_xk, dfik_xs, df_xs, init_param, stepsize): + return tuple( + _array_update(a, b, c, d, stepsize) + for a, b, c, d in zip(dfik_xk, dfik_xs, df_xs, init_param) + ) + + # update if inputs are dicts with either arrays or tuple of arrays as inputs + # behavior is determined by update_fun + def _dict_update(dfik_xk, dfik_xs, df_xs, init_param, stepsize, update_fun): + return { + k: update_fun(dfik_xk[k], dfik_xs[k], df_xs[k], init_param[k], stepsize) + for k in dfik_xk.keys() + } + + if isinstance(true_params, np.ndarray): + next_xk = _array_update(dfik_xk, dfik_xs, df_xs, init_param, stepsize) + elif isinstance(true_params, tuple): + next_xk = _tuple_update(dfik_xk, dfik_xs, df_xs, init_param, stepsize) + elif isinstance(X, dict): + assert ( + set(X.keys()) + == set(dfik_xk.keys()) + == set(dfik_xs.keys()) + == set(df_xs.keys()) + ) + + if isinstance(list(dfik_xk.values())[0], tuple): + update_fun = _tuple_update + else: + update_fun = _array_update + + next_xk = 
+            dfik_xk, dfik_xs, df_xs, init_param, stepsize, update_fun
+        )
+    else:
+        raise TypeError
+
+    next_xk = prox(next_xk, prox_lambda, scaling=stepsize)
+
+    if prox_lambda is None:
+        assert prox == jaxopt.prox.prox_none
+        solver = SVRG(loss)
+    else:
+        solver = ProxSVRG(loss, prox)
+    svrg_next_xk = solver._inner_loop_param_update_step(
+        init_param, xs, df_xs, stepsize, prox_lambda, xi, yi
+    )
+
+    assert pytree_map_and_reduce(
+        lambda a, b: np.allclose(a, b, atol=10**-5, rtol=0.0),
+        all,
+        next_xk,
+        svrg_next_xk,
+    )
+
+
+@pytest.mark.parametrize(
+    "shapes, expected_context",
+    [
+        [
+            (10, 10),
+            does_not_raise(),
+        ],
+        [
+            (10, 8),
+            pytest.raises(
+                ValueError,
+                match="All arguments must have the same sized first dimension.",
+            ),
+        ],
+    ],
+)
+def test_svrg_wrong_shapes(shapes, expected_context):
+    X = np.random.randn(shapes[0], 3)
+    y = np.random.randn(shapes[1], 1)
+
+    init_params = np.random.randn(3, 1)
+
+    def loss_fn(params, X, y):
+        return 1.0
+
+    with expected_context:
+        svrg = SVRG(loss_fn)
+        svrg.run(init_params, X, y)
diff --git a/tests/test_tree_utils.py b/tests/test_tree_utils.py
index 8fe7ec4f..f4e10431 100644
--- a/tests/test_tree_utils.py
+++ b/tests/test_tree_utils.py
@@ -1,3 +1,4 @@
+import numpy as np
 import jax.numpy as jnp
 import pytest
 
@@ -104,3 +105,22 @@ def test_get_valid_tree(tree, expected):
 def test_get_valid_multitree(trees, expected):
     """Test get_valid_multitree function for filtering valid entries across multiple trees."""
     assert jnp.array_equal(tree_utils.get_valid_multitree(*trees), expected)
+
+
+@pytest.mark.parametrize("idx", [
+    slice(2, 5),  # Slice indexing
+    np.array([1, 3, 5]),  # Integer list indexing
+    np.array([True, False, True, False, True, False, True, False, True, False]),  # Boolean array indexing
+    (slice(1, 3), slice(0, 2))  # Mixed indexing (simple example with slices)
+])
+def test_tree_slice(idx):
+    mydict = {
+        'array1': np.random.rand(10, 3),
+        'array2': np.random.rand(10, 2),
+        'array3': np.random.rand(10, 4),
+        'array4': jnp.arange(30).reshape(10, 3)
+    }
+    result = tree_utils.tree_slice(mydict, idx)
+    for key in mydict:
+        expected = mydict[key][idx]
+        assert jnp.all(result[key] == expected)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 85a72bba..c3babd71 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -10,42 +10,6 @@
 from nemos import utils
 
-@pytest.mark.parametrize(
-    "arrays, expected_out",
-    [
-        ([jnp.zeros((10, 1)), np.zeros((10, 1))], jnp.zeros((10, 2))),
-        ([np.zeros((10, 1)), np.zeros((10, 1))], jnp.zeros((10, 2))),
-        (
-            [np.zeros((10, 1)), nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 1)))],
-            nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 2))),
-        ),
-        (
-            [
-                nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 1))),
-                nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 1))),
-            ],
-            nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 2))),
-        ),
-        (
-            [
-                nap.TsdTensor(t=np.arange(10), d=np.zeros((10, 1, 2))),
-                nap.TsdTensor(t=np.arange(10), d=np.zeros((10, 1, 2))),
-            ],
-            nap.TsdTensor(t=np.arange(10), d=np.zeros((10, 2, 2))),
-        ),
-    ],
-)
-def test_concatenate_eval(arrays, expected_out):
-    """Test various combination of arrays and pyapple time series."""
-    out = utils.pynapple_concatenate_jax(arrays, axis=1)
-    if hasattr(expected_out, "times"):
-        assert np.all(out.d == expected_out.d)
-        assert np.all(out.t == expected_out.t)
-        assert np.all(out.time_support.values == expected_out.time_support.values)
-    else:
-        assert np.all(out == expected_out)
-
-
 @pytest.mark.parametrize(
     "arrays, expected_out",
     [