Skip to content

Commit

Permalink
chore: improve error message for ZeroDivisionError in LowessRegression
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Oct 29, 2024
1 parent 9f4d7fa commit 7c8a1e9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
12 changes: 11 additions & 1 deletion sklego/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,17 @@ def predict(self, X):
X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
check_is_fitted(self, ["X_", "y_"])

results = np.stack([np.average(self.y_, weights=self._calc_wts(x_i=x_i)) for x_i in X])
try:
results = np.stack([np.average(self.y_, weights=self._calc_wts(x_i=x_i)) for x_i in X])
except ZeroDivisionError:
msg = (
"Weights, resulting from `np.exp(-(distances**2) / self.sigma)`, are all zero. "
"Try to increase the value of `sigma` or to normalize the input data.\n\n"
"`distances` refer to the distance between each sample `x_i` with all the"
"training samples."
)
raise ValueError(msg)

return results


Expand Down
15 changes: 15 additions & 0 deletions tests/test_estimators/test_lowess.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import re

import numpy as np
import pytest
from sklearn.utils.estimator_checks import parametrize_with_checks

from sklego.linear_model import LowessRegression
Expand All @@ -15,3 +18,15 @@ def test_obvious_usecase():
y = np.ones(x.shape)
y_pred = LowessRegression().fit(X, y).predict(X)
assert np.isclose(y, y_pred).all()


def test_custom_error_for_zero_division():
x = np.arange(0, 100)
X = x.reshape(-1, 1)
y = np.ones(x.shape)
estimator = LowessRegression(sigma=1e-10).fit(X, y)

with pytest.raises(
ValueError, match=re.escape("Weights, resulting from `np.exp(-(distances**2) / self.sigma)`, are all zero.")
):
estimator.predict(X[:10] + 0.5)

0 comments on commit 7c8a1e9

Please sign in to comment.