Skip to content

Commit

Permalink
feat: add equal_scale_axes to 2D scatter plot (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbelak-dtml authored Aug 10, 2023
1 parent 521ce53 commit 307c40a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
11 changes: 11 additions & 0 deletions edvart/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def scatter_plot_2d(
show_xticks: bool = False,
show_yticks: bool = False,
show_zerolines: bool = False,
equal_scale_axes: bool = False,
) -> None:
"""Display a 2D scatter plot of x and y, with optional coloring of points by values in a column.
Expand Down Expand Up @@ -60,6 +61,8 @@ def scatter_plot_2d(
Whether to display ticks on the y axis.
show_zerolines : bool (default = False)
Whether to display zero lines.
equal_scale_axes : bool (default = False)
Whether to make the x and y axes have the same scale.
"""
if isinstance(x, str):
x = df[x]
Expand All @@ -78,6 +81,7 @@ def scatter_plot_2d(
show_xticks=show_xticks,
show_yticks=show_yticks,
show_zerolines=show_zerolines,
equal_scale_axes=equal_scale_axes,
)


Expand All @@ -93,6 +97,7 @@ def _scatter_plot_2d_noninteractive(
show_xticks: bool = False,
show_yticks: bool = False,
show_zerolines: bool = False,
equal_scale_axes: bool = False,
) -> None:
_fig, ax = plt.subplots(figsize=figsize)
if color_col is not None:
Expand Down Expand Up @@ -123,6 +128,8 @@ def _scatter_plot_2d_noninteractive(
ax.set_yticks([])
if not show_zerolines:
ax.grid(False)
if equal_scale_axes:
ax.set_aspect("equal", "datalim")
plt.show()


Expand All @@ -138,6 +145,7 @@ def _scatter_plot_2d_interactive(
show_xticks: bool = False,
show_yticks: bool = False,
show_zerolines: bool = False,
equal_scale_axes: bool = False,
) -> None:
layout = go.Layout(
width=figsize[0] * _INCHES_TO_PIXELS,
Expand All @@ -150,6 +158,9 @@ def _scatter_plot_2d_interactive(
),
legend=go.layout.Legend(title=f"<b>{color_col}</b>"),
)
if equal_scale_axes:
layout.yaxis.scaleanchor = "x"
layout.yaxis.scaleratio = 1
fig = go.Figure(layout=layout)
if color_col is not None:
is_color_categorical = utils.is_categorical(df[color_col]) or not is_numeric(df[color_col])
Expand Down
1 change: 1 addition & 0 deletions edvart/report_sections/multivariate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def pca_first_vs_second(
show_xticks=True,
show_yticks=True,
show_zerolines=True,
equal_scale_axes=True,
)

print(f"Explained variance ratio: {pca.explained_variance_ratio_[:2].sum() * 100 :.2f}%")
Expand Down

0 comments on commit 307c40a

Please sign in to comment.