Skip to content

Commit

Permalink
refactor: split scatter_plot_2d into 2 functions. (#39)
Browse files Browse the repository at this point in the history
The scatter plot has an interactive and a non-interactive version. The
top-level function now just deciedes which of the two new functions to
call.
  • Loading branch information
mbelak-dtml authored Aug 1, 2023
1 parent 7b902a3 commit 34195cf
Showing 1 changed file with 116 additions and 75 deletions.
191 changes: 116 additions & 75 deletions edvart/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,77 +68,135 @@ def scatter_plot_2d(
x = df[x]
if isinstance(y, str):
y = df[y]
plotting_func = _scatter_plot_2d_interactive if interactive else _scatter_plot_2d_noninteractive
plotting_func(
df=df,
x=x,
y=y,
color_col=color_col,
figsize=figsize,
opacity=opacity,
xlabel=xlabel,
ylabel=ylabel,
show_xticks=show_xticks,
show_yticks=show_yticks,
show_zerolines=show_zerolines,
)

if interactive:
layout = dict(
width=figsize[0] * _INCHES_TO_PIXELS,
height=figsize[1] * _INCHES_TO_PIXELS,
xaxis=dict(showgrid=False, showticklabels=show_xticks, zeroline=show_zerolines),
yaxis=dict(showgrid=False, showticklabels=show_yticks, zeroline=show_zerolines),
legend=dict(title=f"<b>{color_col}</b>"),
)

def _scatter_plot_2d_noninteractive(
df: pd.DataFrame,
x: Union[str, pd.Series, np.ndarray],
y: Union[str, pd.Series, np.ndarray],
color_col: Optional[str] = None,
figsize: Tuple[float, float] = (12, 12),
opacity: float = 0.8,
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
show_xticks: bool = False,
show_yticks: bool = False,
show_zerolines: bool = False,
) -> None:
_fig, ax = plt.subplots(figsize=figsize)
if color_col is not None:
is_color_categorical = utils.is_categorical(df[color_col]) or not is_numeric(df[color_col])
if is_color_categorical:
color_categorical = pd.Categorical(df[color_col])
color_codes = color_categorical.codes
else:
color_codes = df[color_col]
scatter = ax.scatter(x, y, c=color_codes, alpha=opacity)

if interactive:
fig = go.Figure(layout=layout)
if is_color_categorical:
df = df.copy()
x_name, y_name = "__edvart_scatter_x", "__edvart_scatter_y"
df[x_name] = x
df[y_name] = y
for group_name, group in df.groupby(color_col):
fig.add_trace(
go.Scatter(
x=group[x_name],
y=group[y_name],
mode="markers",
marker=dict(opacity=opacity),
name=group_name,
text=[
"</br>".join(
f"{col_name}: {df.loc[row, col_name]}"
for col_name in group.columns.drop([x_name, y_name])
)
for row in group.index
],
)
)
else:
if is_color_categorical:
legend_elements = scatter.legend_elements()
ax.legend(legend_elements[0], color_categorical.categories, title=color_col)
else:
cbar = plt.colorbar(scatter)
cbar.ax.set_ylabel(color_col)
else:
ax.scatter(x, y, alpha=opacity)

if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
if not show_xticks:
ax.set_xticks([])
if not show_yticks:
ax.set_yticks([])
if not show_zerolines:
ax.grid(False)
plt.show()


def _scatter_plot_2d_interactive(
df: pd.DataFrame,
x: Union[str, pd.Series, np.ndarray],
y: Union[str, pd.Series, np.ndarray],
color_col: Optional[str] = None,
figsize: Tuple[float, float] = (12, 12),
opacity: float = 0.8,
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
show_xticks: bool = False,
show_yticks: bool = False,
show_zerolines: bool = False,
) -> None:
layout = go.Layout(
width=figsize[0] * _INCHES_TO_PIXELS,
height=figsize[1] * _INCHES_TO_PIXELS,
xaxis=go.layout.XAxis(
showgrid=False, showticklabels=show_xticks, zeroline=show_zerolines, title=xlabel
),
yaxis=go.layout.YAxis(
showgrid=False, showticklabels=show_yticks, zeroline=show_zerolines, title=ylabel
),
legend=go.layout.Legend(title=f"<b>{color_col}</b>"),
)
fig = go.Figure(layout=layout)
if color_col is not None:
is_color_categorical = utils.is_categorical(df[color_col]) or not is_numeric(df[color_col])
if is_color_categorical:
df = df.copy()
x_name, y_name = "__edvart_scatter_x", "__edvart_scatter_y"
df[x_name] = x
df[y_name] = y
for group_name, group in df.groupby(color_col):
fig.add_trace(
go.Scatter(
x=x,
y=y,
x=group[x_name],
y=group[y_name],
mode="markers",
marker=dict(
color=df[color_col], opacity=opacity, colorbar=dict(title=color_col)
),
marker=dict(opacity=opacity),
name=group_name,
text=[
"</br>".join(
f"{col_name}: {df.loc[row, col_name]}" for col_name in df.columns
f"{col_name}: {df.loc[row, col_name]}"
for col_name in group.columns.drop([x_name, y_name])
)
for row in df.index
for row in group.index
],
),
)
)
else:
if is_color_categorical:
color_categorical = pd.Categorical(df[color_col])
color_codes = color_categorical.codes
else:
color_codes = df[color_col]

fig, ax = plt.subplots(figsize=figsize)
scatter = ax.scatter(x, y, c=color_codes, alpha=opacity)
if is_color_categorical:
legend_elements = scatter.legend_elements()
ax.legend(legend_elements[0], color_categorical.categories, title=color_col)
else:
cbar = plt.colorbar(scatter)
cbar.ax.set_ylabel(color_col)
elif interactive:
fig = go.Figure(
fig.add_trace(
go.Scatter(
x=x,
y=y,
mode="markers",
marker=dict(
color=df[color_col], opacity=opacity, colorbar=dict(title=color_col)
),
text=[
"</br>".join(
f"{col_name}: {df.loc[row, col_name]}" for col_name in df.columns
)
for row in df.index
],
),
)
else: # color_col is None
fig.add_trace(
go.Scatter(
x=x,
y=y,
Expand All @@ -149,23 +207,6 @@ def scatter_plot_2d(
for row in df.index
],
),
layout=layout,
)
else:
fig, ax = plt.subplots(figsize=figsize)
ax.scatter(x, y, alpha=opacity)

if interactive:
fig.show()
else:
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
if not show_xticks:
ax.set_xticks([])
if not show_yticks:
ax.set_yticks([])
if not show_zerolines:
ax.grid(False)
plt.show()
fig.show()

0 comments on commit 34195cf

Please sign in to comment.