Skip to content

Commit

Permalink
TEST: Added tests for psalsa and derpsalsa weightings
Browse files Browse the repository at this point in the history
  • Loading branch information
derb12 committed Dec 19, 2024
1 parent fcf4cfd commit a4413fd
Show file tree
Hide file tree
Showing 10 changed files with 328 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pybaselines/_weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def _derpsalsa(y, baseline, p, k, shape_y, partial_weights):
# since it's faster than performing the square and exp on the full residual
weights = np.full(shape_y, 1 - p, dtype=float)
mask = residual > 0
weights[mask] = p * np.exp(-((residual[mask] / k)**2) / 2)
weights[mask] = p * np.exp(-0.5 * ((residual[mask] / k)**2))
weights *= partial_weights
return weights

Expand Down
16 changes: 12 additions & 4 deletions pybaselines/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,8 @@ def pspline_psalsa(self, data, lam=1e3, p=0.5, k=None, num_knots=100, spline_deg
Raises
------
ValueError
Raised if `p` is not between 0 and 1.
Raised if `p` is not between 0 and 1. Also raised if `k` is not greater
than 0.
See Also
--------
Expand All @@ -1104,6 +1105,8 @@ def pspline_psalsa(self, data, lam=1e3, p=0.5, k=None, num_knots=100, spline_deg
)
if k is None:
k = np.std(y) / 10
else:
k = _check_scalar_variable(k, variable_name='k')
tol_history = np.empty(max_iter + 1)
for i in range(max_iter + 1):
baseline = self.pspline.solve_pspline(y, weight_array)
Expand Down Expand Up @@ -1185,7 +1188,8 @@ def pspline_derpsalsa(self, data, lam=1e2, p=1e-2, k=None, num_knots=100, spline
Raises
------
ValueError
Raised if `p` is not between 0 and 1.
Raised if `p` is not between 0 and 1. Also raised if `k` is not greater
than 0.
See Also
--------
Expand All @@ -1208,6 +1212,8 @@ def pspline_derpsalsa(self, data, lam=1e2, p=1e-2, k=None, num_knots=100, spline
)
if k is None:
k = np.std(y) / 10
else:
k = _check_scalar_variable(k, variable_name='k')

if smooth_half_window is None:
smooth_half_window = self._size // 200
Expand Down Expand Up @@ -2248,7 +2254,8 @@ def pspline_psalsa(data, lam=1e3, p=0.5, k=None, num_knots=100, spline_degree=3,
Raises
------
ValueError
Raised if `p` is not between 0 and 1.
Raised if `p` is not between 0 and 1. Also raised if `k` is not greater
than 0.
See Also
--------
Expand Down Expand Up @@ -2336,7 +2343,8 @@ def pspline_derpsalsa(data, lam=1e2, p=1e-2, k=None, num_knots=100, spline_degre
Raises
------
ValueError
Raised if `p` is not between 0 and 1.
Raised if `p` is not between 0 and 1. Also raised if `k` is not greater
than 0.
See Also
--------
Expand Down
6 changes: 5 additions & 1 deletion pybaselines/two_d/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np

from .. import _weighting
from .._validation import _check_scalar_variable
from ..utils import ParameterWarning, gaussian, relative_difference
from ._algorithm_setup import _Algorithm2D
from ._whittaker_utils import PenalizedSystem2D
Expand Down Expand Up @@ -765,7 +766,8 @@ def pspline_psalsa(self, data, lam=1e3, p=0.5, k=None, num_knots=25, spline_degr
Raises
------
ValueError
Raised if `p` is not between 0 and 1.
Raised if `p` is not between 0 and 1. Also raised if `k` is not greater
than 0.
See Also
--------
Expand All @@ -789,6 +791,8 @@ def pspline_psalsa(self, data, lam=1e3, p=0.5, k=None, num_knots=25, spline_degr
)
if k is None:
k = np.std(y) / 10
else:
k = _check_scalar_variable(k, variable_name='k')
tol_history = np.empty(max_iter + 1)
for i in range(max_iter + 1):
baseline = self.pspline.solve(y, weight_array)
Expand Down
5 changes: 4 additions & 1 deletion pybaselines/two_d/whittaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,8 @@ def psalsa(self, data, lam=1e5, p=0.5, k=None, diff_order=2, max_iter=50, tol=1e
Raises
------
ValueError
Raised if `p` is not between 0 and 1.
Raised if `p` is not between 0 and 1. Also raised if `k` is not greater
than 0.
Notes
-----
Expand All @@ -789,6 +790,8 @@ def psalsa(self, data, lam=1e5, p=0.5, k=None, diff_order=2, max_iter=50, tol=1e
)
if k is None:
k = np.std(y) / 10
else:
k = _check_scalar_variable(k, variable_name='k')

shape = self._shape if self.whittaker_system._using_svd else self._size
tol_history = np.empty(max_iter + 1)
Expand Down
16 changes: 12 additions & 4 deletions pybaselines/whittaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,8 @@ def psalsa(self, data, lam=1e5, p=0.5, k=None, diff_order=2, max_iter=50, tol=1e
Raises
------
ValueError
Raised if `p` is not between 0 and 1.
Raised if `p` is not between 0 and 1. Also raised if `k` is not greater
than 0.
Notes
-----
Expand All @@ -679,6 +680,8 @@ def psalsa(self, data, lam=1e5, p=0.5, k=None, diff_order=2, max_iter=50, tol=1e
y, weight_array = self._setup_whittaker(data, lam, diff_order, weights)
if k is None:
k = np.std(y) / 10
else:
k = _check_scalar_variable(k, variable_name='k')
tol_history = np.empty(max_iter + 1)
for i in range(max_iter + 1):
baseline = self.whittaker_system.solve(
Expand Down Expand Up @@ -758,7 +761,8 @@ def derpsalsa(self, data, lam=1e6, p=0.01, k=None, diff_order=2, max_iter=50, to
Raises
------
ValueError
Raised if `p` is not between 0 and 1.
Raised if `p` is not between 0 and 1. Also raised if `k` is not greater
than 0.
References
----------
Expand All @@ -772,6 +776,8 @@ def derpsalsa(self, data, lam=1e6, p=0.01, k=None, diff_order=2, max_iter=50, to
y, weight_array = self._setup_whittaker(data, lam, diff_order, weights)
if k is None:
k = np.std(y) / 10
else:
k = _check_scalar_variable(k, variable_name='k')
if smooth_half_window is None:
smooth_half_window = self._size // 200
# could pad the data every iteration, but it is ~2-3 times slower and only affects
Expand Down Expand Up @@ -1283,7 +1289,8 @@ def psalsa(data, lam=1e5, p=0.5, k=None, diff_order=2, max_iter=50, tol=1e-3,
Raises
------
ValueError
Raised if `p` is not between 0 and 1.
Raised if `p` is not between 0 and 1. Also raised if `k` is not greater
than 0.
Notes
-----
Expand Down Expand Up @@ -1366,7 +1373,8 @@ def derpsalsa(data, lam=1e6, p=0.01, k=None, diff_order=2, max_iter=50, tol=1e-3
Raises
------
ValueError
Raised if `p` is not between 0 and 1.
Raised if `p` is not between 0 and 1. Also raised if `k` is not greater
than 0.
References
----------
Expand Down
12 changes: 12 additions & 0 deletions tests/test_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,12 @@ def test_whittaker_comparison(self, lam, p, diff_order):
self, whittaker.psalsa, self.y, lam=lam, p=p, diff_order=diff_order
)

@pytest.mark.parametrize('k', (0, -1))
def test_outside_k_fails(self, k):
"""Ensures k values not greater than 0 raise an exception."""
with pytest.raises(ValueError):
self.class_func(self.y, k=k)


class TestPsplineDerpsalsa(IterativeSplineTester):
"""Class for testing pspline_derpsalsa baseline."""
Expand Down Expand Up @@ -497,6 +503,12 @@ def test_whittaker_comparison(self, lam, p, diff_order):
self, whittaker.derpsalsa, self.y, lam=lam, p=p, diff_order=diff_order
)

@pytest.mark.parametrize('k', (0, -1))
def test_outside_k_fails(self, k):
"""Ensures k values not greater than 0 raise an exception."""
with pytest.raises(ValueError):
self.class_func(self.y, k=k)


class TestPsplineMPLS(SplineTester, InputWeightsMixin):
"""Class for testing pspline_mpls baseline."""
Expand Down
Loading

0 comments on commit a4413fd

Please sign in to comment.