Skip to content

Commit

Permalink
Updated the plot_regression_line method (#28)
Browse files Browse the repository at this point in the history
* Updated the plot_regression_line method

* Updated method docstring

* Minor update to docstring
  • Loading branch information
hasnainroopawalla authored Jan 14, 2022
1 parent 9fbc366 commit bbaacc8
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 21 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setup(
name="showml",
version="1.5.10",
version="1.5.11",
packages=find_packages(exclude="tests"),
description="A Python package of Machine Learning Algorithms implemented from scratch",
long_description=long_description,
Expand Down
4 changes: 2 additions & 2 deletions showml/examples/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from showml.linear_model.regression import LinearRegression
from showml.utils.dataset import Dataset
from showml.losses.metrics import mean_squared_error, r2_score
from showml.utils.plots import plot_regression_line
from showml.utils.plots import plot_actual_vs_predicted
from showml.utils.preprocessing import normalize
from showml.utils.data_loader import load_auto

Expand All @@ -21,4 +21,4 @@
model.fit(dataset, epochs=10000)
model.plot_metrics()

plot_regression_line(X_train, y_train, model.predict(X_train))
plot_actual_vs_predicted(y_train, model.predict(X_train))
25 changes: 7 additions & 18 deletions showml/utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import matplotlib.pyplot as plt
import numpy as np

from showml.utils.exceptions import InvalidShapeError


def generic_metric_plot(metric_name: str, metric_values: List[float]) -> None:
"""Plot the metric values after training (epoch vs metric).
Expand All @@ -19,25 +17,16 @@ def generic_metric_plot(metric_name: str, metric_values: List[float]) -> None:
plt.show()


def plot_regression_line(
X: np.ndarray, y: np.ndarray, z: np.ndarray, xlabel: str = "", ylabel: str = ""
) -> None:
"""Plot the regression line to visualize how well the model fits to the data.
Only works when the entire dataset is 2-dimensional i.e., input data (X) is 1-dimensional.
def plot_actual_vs_predicted(y: np.ndarray, z: np.ndarray) -> None:
"""Generates a scatter plot of the true values and predicted values.
A diagonal line from (0, 0) to (+limit, +limit) indicates a very good fit i.e., the true values and predicted values are almost equal.
Args:
X (np.ndarray): The input data.
y (np.ndarray): The true labels of the input data.
z (np.ndarray): The predicted values for the input data.
xlabel (str, optional): The label corresponding to the input feature name. Defaults to "".
ylabel (str, optional): The label corresponding to the output feature name. Defaults to "".
"""
if X.shape[1] != 1:
raise InvalidShapeError("X must have exactly 1 dimension.")

plt.scatter(X, y, color="red")
plt.plot(X, z, color="blue")
plt.title("Regression Line")
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.scatter(y, z, color="red")
plt.title("Actual vs Predicted")
plt.xlabel("Actual")
plt.ylabel("Predicted")
plt.show()

0 comments on commit bbaacc8

Please sign in to comment.