Skip to content

Commit

Permalink
scatter(): proper linear regression model fit (#37)
Browse files Browse the repository at this point in the history
* scatter(): switch to sns.regplot()

* scatter(): add fit and order argument

* fix flake error
  • Loading branch information
frankenjoe authored Dec 10, 2021
1 parent 40866b5 commit 65a3034
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions audplot/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,13 +496,19 @@ def scatter(
truth: typing.Union[typing.Sequence, pd.Series],
prediction: typing.Union[typing.Sequence, pd.Series],
*,
fit: bool = False,
order: int = 1,
ax: matplotlib.axes.Axes = None,
):
r"""Scatter plot of truth and predicted values.
Args:
truth: truth values
prediction: predicted values
fit: if ``True``,
fit a regression model relating the x and y variables
order: if greater than 1,
estimate a polynomial regression model (see ``fit``)
ax: pre-existing axes for the plot.
Otherwise, calls :func:`matplotlib.pyplot.gca()` internally
Expand All @@ -518,24 +524,21 @@ def scatter(
>>> truth = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> prediction = [0.1, 0.8, 2.3, 2.4, 3.9, 5, 6.2, 7.1, 7.8, 9, 9]
>>> scatter(truth, prediction)
>>> scatter(truth, prediction, fit=True)
"""
ax = ax or plt.gca()
minimum = min([min(truth), min(prediction)])
maximum = max([max(truth), max(prediction)])
ax.scatter(truth, prediction)
ax.plot(
[minimum, maximum],
[minimum, maximum],
color='r',
sns.regplot(
x=truth,
y=prediction,
fit_reg=fit,
line_kws={'color': 'r'},
order=order,
ax=ax,
)
ax.set_xlim(minimum, maximum)
ax.set_ylim(minimum, maximum)
ax.set_xlabel('Truth')
ax.set_ylabel('Prediction')
ax.grid(alpha=0.4)
ax.set_axisbelow(True)
sns.despine(ax=ax)


Expand Down

0 comments on commit 65a3034

Please sign in to comment.