From bbaacc81779437ea2ef09d7869b1f8a824f80353 Mon Sep 17 00:00:00 2001 From: Hasnain Roopawalla <37022937+hasnainroopawalla@users.noreply.github.com> Date: Fri, 14 Jan 2022 18:09:06 +0100 Subject: [PATCH] Updated the plot_regression_line method (#28) * Updated the plot_regression_line method * Updated method docstring * Minor update to docstring --- setup.py | 2 +- showml/examples/linear_regression.py | 4 ++-- showml/utils/plots.py | 25 +++++++------------------ 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/setup.py b/setup.py index dde1907..73ccc29 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name="showml", - version="1.5.10", + version="1.5.11", packages=find_packages(exclude="tests"), description="A Python package of Machine Learning Algorithms implemented from scratch", long_description=long_description, diff --git a/showml/examples/linear_regression.py b/showml/examples/linear_regression.py index 6d53f0a..ca696a7 100644 --- a/showml/examples/linear_regression.py +++ b/showml/examples/linear_regression.py @@ -3,7 +3,7 @@ from showml.linear_model.regression import LinearRegression from showml.utils.dataset import Dataset from showml.losses.metrics import mean_squared_error, r2_score -from showml.utils.plots import plot_regression_line +from showml.utils.plots import plot_actual_vs_predicted from showml.utils.preprocessing import normalize from showml.utils.data_loader import load_auto @@ -21,4 +21,4 @@ model.fit(dataset, epochs=10000) model.plot_metrics() -plot_regression_line(X_train, y_train, model.predict(X_train)) +plot_actual_vs_predicted(y_train, model.predict(X_train)) diff --git a/showml/utils/plots.py b/showml/utils/plots.py index 2467ed2..94fca0e 100644 --- a/showml/utils/plots.py +++ b/showml/utils/plots.py @@ -2,8 +2,6 @@ import matplotlib.pyplot as plt import numpy as np -from showml.utils.exceptions import InvalidShapeError - def generic_metric_plot(metric_name: str, metric_values: List[float]) -> None: """Plot the metric values after training (epoch vs metric). @@ -19,25 +17,16 @@ def generic_metric_plot(metric_name: str, metric_values: List[float]) -> None: plt.show() -def plot_regression_line( - X: np.ndarray, y: np.ndarray, z: np.ndarray, xlabel: str = "", ylabel: str = "" -) -> None: - """Plot the regression line to visualize how well the model fits to the data. - Only works when the entire dataset is 2-dimensional i.e., input data (X) is 1-dimensional. +def plot_actual_vs_predicted(y: np.ndarray, z: np.ndarray) -> None: + """Generates a scatter plot of the true values and predicted values. + A diagonal line from (0, 0) to (+limit, +limit) indicates a very good fit i.e., the true values and predicted values are almost equal. Args: - X (np.ndarray): The input data. y (np.ndarray): The true labels of the input data. z (np.ndarray): The predicted values for the input data. - xlabel (str, optional): The label corresponding to the input feature name. Defaults to "". - ylabel (str, optional): The label corresponding to the output feature name. Defaults to "". """ - if X.shape[1] != 1: - raise InvalidShapeError("X must have exactly 1 dimension.") - - plt.scatter(X, y, color="red") - plt.plot(X, z, color="blue") - plt.title("Regression Line") - plt.xlabel(xlabel) - plt.ylabel(ylabel) + plt.scatter(y, z, color="red") + plt.title("Actual vs Predicted") + plt.xlabel("Actual") + plt.ylabel("Predicted") plt.show()