Skip to content

Commit

Permalink
feat(eda): added target analysis given numerical target column
Browse files Browse the repository at this point in the history
feat(eda): added basic numerical target analysis

squash this

feat(eda): added target analysis given numerical target column
  • Loading branch information
Devin Lu committed Apr 21, 2022
1 parent ec857e7 commit c801377
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 29 deletions.
3 changes: 2 additions & 1 deletion dataprep/eda/create_diff_report/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

def create_diff_report(
df_list: Union[List[pd.DataFrame], Dict[str, pd.DataFrame]],
target: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
display: Optional[List[str]] = None,
title: Optional[str] = "DataPrep Report",
Expand Down Expand Up @@ -63,7 +64,7 @@ def create_diff_report(
_suppress_warnings()
cfg = Config.from_dict(display, config)

components = format_diff_report(df_list, cfg, mode, progress)
components = format_diff_report(df_list, cfg, mode, progress, target)

dict_stats = defaultdict(list)

Expand Down
35 changes: 25 additions & 10 deletions dataprep/eda/create_diff_report/diff_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def format_diff_report(
cfg: Config,
mode: Optional[str],
progress: bool = True,
target: Optional[str] = None
) -> Dict[str, Any]:
"""
Format the data and figures needed by create_diff_report
Expand Down Expand Up @@ -110,13 +111,26 @@ def format_diff_report(
if mode == "basic":
# note: we need the type ignore comment for mypy otherwise it complains because
# it doesn't realize that we converted df_list to a list if it's a dictionary
report = format_basic(df_list, cfg) # type: ignore
if target:
validate_target(target, df_list)
report = format_basic(df_list, target, cfg) # type: ignore
else:
raise ValueError(f"Unknown mode: {mode}")
return report

def validate_target(target: str, df_list: List[pd.DataFrame]):
"""
Helper function, verify that target column exists
"""
exists = False
for df in df_list:
if target in df.columns:
exists = True
break
if not exists:
raise ValueError(f'Sorry, {target} is not a valid column')

def format_basic(df_list: List[pd.DataFrame], cfg: Config) -> Dict[str, Any]:
def format_basic(df_list: List[pd.DataFrame], target: Optional[str], cfg: Config) -> Dict[str, Any]:
"""
Format basic version.
Expand Down Expand Up @@ -158,7 +172,7 @@ def format_basic(df_list: List[pd.DataFrame], cfg: Config) -> Dict[str, Any]:
# data = dask.compute(data)
delayed_results.append(data)

res_plots = dask.delayed(_format_plots)(cfg=cfg, df_list=df_list)
res_plots = dask.delayed(_format_plots)(cfg=cfg, df_list=df_list, target=target)

dask_results["df_computations"] = delayed_results
dask_results["plots"] = res_plots
Expand Down Expand Up @@ -211,7 +225,7 @@ def basic_computations(df: EDAFrame, cfg: Config) -> Dict[str, Any]:


def compute_plot_data(
df_list: List[dd.DataFrame], cfg: Config, dtype: Optional[DTypeDef]
pd_list: List[pd.DataFrame], cfg: Config, dtype: Optional[DTypeDef], target: Optional[str]
) -> Intermediate:
"""
Compute function for create_diff_report's plots
Expand All @@ -229,6 +243,10 @@ def compute_plot_data(
"""
# pylint: disable=too-many-branches, too-many-locals

df_list = list(map(to_dask, pd_list))
for i, _ in enumerate(df_list):
df_list[i].columns = df_list[i].columns.astype(str)

dfs = Dfs(df_list)
dfs_cols = dfs.columns.apply("to_list").data

Expand Down Expand Up @@ -277,7 +295,7 @@ def compute_plot_data(
elif is_dtype(dtp, DateTime_v1()):
plot_data.append((col, dtp, dask.compute(*datum), orig)) # workaround

return Intermediate(data=plot_data, stats=stats, visual_type="comparison_grid")
return Intermediate(data=plot_data, stats=stats, visual_type="comparison_grid", target=target, df_list=pd_list)


def _compute_variables(df: EDAFrame, cfg: Config) -> Dict[str, Any]:
Expand Down Expand Up @@ -407,14 +425,11 @@ def _format_variables(df: EDAFrame, cfg: Config, data: Dict[str, Any]) -> Dict[s


def _format_plots(
df_list: Union[List[pd.DataFrame], Dict[str, pd.DataFrame]], cfg: Config
df_list: Union[List[pd.DataFrame], Dict[str, pd.DataFrame]], cfg: Config, target: Optional[str]
) -> Dict[str, Any]:
"""Formatting of plots section"""
df_list = list(map(to_dask, df_list))
for i, _ in enumerate(df_list):
df_list[i].columns = df_list[i].columns.astype(str)

itmdt = compute_plot_data(df_list=df_list, cfg=cfg, dtype=None)
itmdt = compute_plot_data(pd_list=df_list, cfg=cfg, dtype=None, target=target)
return render_diff(itmdt, cfg=cfg)


Expand Down
146 changes: 128 additions & 18 deletions dataprep/eda/diff/render.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
"""
This module implements the visualization for the plot_diff function.
""" # pylint: disable=too-many-lines
from turtle import color
from typing import Any, Dict, List, Tuple, Optional

from sklearn.preprocessing import MinMaxScaler
import math
import numpy as np
import pandas as pd
import dask.array as da
import matplotlib.pyplot as plt
from bokeh.models import (
HoverTool,
Panel,
FactorRange,
)
from bokeh.plotting import Figure, figure
from bokeh.plotting import Figure, figure, show
from bokeh.transform import dodge
from bokeh.layouts import row
from bokeh.models.ranges import Range1d
from bokeh.models import LinearAxis

from ..configs import Config
from ..dtypes import Continuous, DateTime, Nominal, is_dtype
Expand Down Expand Up @@ -78,6 +83,8 @@ def bar_viz(
orig: List[str],
df_labels: List[str],
baseline: int,
target: Optional[str] = None,
df_list: Optional[List[pd.DataFrame]] = None
) -> Figure:
"""
Render a bar chart
Expand All @@ -94,6 +101,12 @@ def bar_viz(
("Source", "@orig"),
]

col1_min = df[0][col].min()
col2_min = df[1][col].min()
col1_max = df[0][col].max()
col2_max = df[1][col].max()
y_inc = 0.05

if show_yticks:
if len(df[baseline]) > 10:
plot_width = 28 * len(df[baseline])
Expand All @@ -106,12 +119,15 @@ def bar_viz(
tools="hover",
x_range=list(df[baseline].index),
y_axis_type=yscale,
y_range=(min(col1_min, col2_min) * (1 - y_inc), max(col1_max, col2_max) * (1 + y_inc))
)

row_names = None
offset = np.linspace(-0.08 * len(df), 0.08 * len(df), len(df)) if len(df) > 1 else [0]
for i, (nrow, data) in enumerate(zip(nrows, df)):
data["pct"] = data[col] / nrow * 100
data.index = [str(val) for val in data.index]
if row_names is None:
row_names = data.index
data["orig"] = orig[i]

fig.vbar(
Expand All @@ -126,7 +142,6 @@ def bar_viz(
tweak_figure(fig, "bar", show_yticks)

fig.yaxis.axis_label = "Count"

x_axis_label = ""
if ttl_grps > len(df[baseline]):
x_axis_label += f"Top {len(df[baseline])} of {ttl_grps} {col}"
Expand All @@ -142,6 +157,21 @@ def bar_viz(

if show_yticks and yscale == "linear":
_format_axis(fig, 0, df[baseline].max(), "y")

df1, df2 = df_list[0], df_list[1]
if target != col and target and col in df1.columns and col in df2.columns:
col1, col2 = df_list[0][col], df_list[1][col]
row_avgs_1 = []
row_avgs_2 = []
for names in row_names:
row_avgs_1.append(df_list[0][target][col1 == names].mean())
row_avgs_2.append(df_list[1][target][col2 == names].mean())

row_avgs_1 = [0 if math.isnan(x) else x for x in row_avgs_1]
row_avgs_2 = [0 if math.isnan(x) else x for x in row_avgs_2]
fig.extra_y_ranges = {"Averages": Range1d(start=min(row_avgs_1 + row_avgs_2) * (1 - y_inc), end=max(row_avgs_1 + row_avgs_2) * (1 + y_inc))}
fig.multi_line([row_names, row_names], [row_avgs_1, row_avgs_2], color=['navy', 'firebrick'], y_range_name="Averages", line_width=4)
fig.add_layout(LinearAxis(y_range_name="Averages"), 'right')
return fig


Expand All @@ -155,28 +185,56 @@ def hist_viz(
show_yticks: bool,
df_labels: List[str],
orig: Optional[List[str]] = None,
target: Optional[str] = None,
df_list: Optional[List[pd.DataFrame]] = None
) -> Figure:
"""
Render a histogram
"""
# pylint: disable=too-many-arguments,too-many-locals

tooltips = [
("Bin", "@intvl"),
("Frequency", "@freq"),
("Percent", "@pct{0.2f}%"),
("Source", "@orig"),
]
df1, df2 = df_list[0], df_list[1]
y_inc = 0.05
tooltips = [
("Bin", "@intvl"),
("Frequency", "@freq"),
("Percent", "@pct{0.2f}%"),
("Source", "@orig"),
]
fig = None

y_start, y_end = None, None
counts_list = []
if target and target != col and col in df1.columns and col in df2.columns:
for hst in hist:
counts, bins = hst
counts_list.append(counts)

counts_min_1 = min(counts_list[0])
counts_min_2 = min(counts_list[1])

counts_max_1 = max(counts_list[0])
counts_max_2 = max(counts_list[1])

y_start, y_end = min(counts_min_1, counts_min_2), max(counts_max_1, counts_max_2)


fig = Figure(
plot_height=plot_height,
plot_width=plot_width,
title=col,
toolbar_location=None,
y_axis_type=yscale,
y_axis_type=yscale
)

bins_list = []
for i, hst in enumerate(hist):
counts, bins = hst
bins_list.append(bins)
if sum(counts) == 0:
fig.rect(x=0, y=0, width=0, height=0)
continue
Expand All @@ -192,16 +250,34 @@ def hist_viz(
}
)
bottom = 0 if yscale == "linear" or df.empty else counts.min() / 2
fig.quad(
source=df,
left="left",
right="right",
bottom=bottom,
alpha=0.5,
top="freq",
fill_color=CATEGORY10[i],
line_color=CATEGORY10[i],
)
if y_start is not None and y_end is not None:
# fig.y_range = (y_start * (1 - y_inc), y_end * (1 + y_inc))
fig.extra_y_ranges = {"Counts": Range1d(start=y_start * (1 - y_inc), end=y_end * (1 + y_inc))}
fig.quad(
source=df,
left="left",
right="right",
bottom=bottom,
alpha=0.5,
top="freq",
fill_color=CATEGORY10[i],
line_color=CATEGORY10[i],
y_range_name="Counts"
)
else:
fig.quad(
source=df,
left="left",
right="right",
bottom=bottom,
alpha=0.5,
top="freq",
fill_color=CATEGORY10[i],
line_color=CATEGORY10[i]
)
# if col == 'LotFrontage':
# breakpoint()

hover = HoverTool(tooltips=tooltips, attachment="vertical", mode="vline")
fig.add_tools(hover)

Expand All @@ -224,6 +300,34 @@ def hist_viz(
fig.xaxis.axis_label = x_axis_label
fig.xaxis.axis_label_standoff = 0

if target and target != col and col in df1.columns and col in df2.columns:
col1, col2 = df1[col], df2[col]
source1, source2 = col1, col2
col1 = col1[~np.isnan(col1)]
col2 = col2[~np.isnan(col2)]
num_bins1 = len(bins_list[0]) - 1
num_bins2 = len(bins_list[1]) - 1
bins_1, bins_2 = bins_list[0], bins_list[1]

df1_source_bins_series = pd.cut(source1, bins=bins_1, labels=False)
df1_bin_averages = [None] * num_bins1

df2_source_bins_series = pd.cut(source2, bins=bins_2, labels=False)
df2_bin_averages = [None] * num_bins2

for b in range(num_bins1):
df1_bin_averages[b] = df1[target][df1_source_bins_series == b].mean()
for b in range(num_bins2):
df2_bin_averages[b] = df2[target][df2_source_bins_series == b].mean()

df1_bin_averages = [0 if math.isnan(x) else x for x in df1_bin_averages]
df2_bin_averages = [0 if math.isnan(x) else x for x in df2_bin_averages]
max_range = max(df1_bin_averages + df2_bin_averages)
min_range = min(df1_bin_averages + df2_bin_averages)

fig.extra_y_ranges['Averages'] = Range1d(start=min_range * (1 - y_inc), end=max_range * (1 + y_inc))
fig.multi_line([bins_1, bins_2], [df1_bin_averages, df2_bin_averages], color=['navy', 'firebrick'], y_range_name="Averages", line_width=4)
fig.add_layout(LinearAxis(y_range_name="Averages", axis_label='Bin Averages'), 'right')
return fig


Expand Down Expand Up @@ -610,6 +714,9 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
nrows = itmdt["stats"]["nrows"]
titles: List[str] = []

df_list = itmdt.df_list
target = itmdt.target

for col, dtp, data, orig in itmdt["data"]:
fig = None
if is_dtype(dtp, Nominal()):
Expand All @@ -626,6 +733,8 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
orig,
df_labels,
baseline if len(df) > 1 else 0,
target,
df_list
)
elif is_dtype(dtp, Continuous()):
if cfg.diff.density:
Expand All @@ -643,6 +752,8 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
False,
df_labels,
orig,
target,
df_list
)
elif is_dtype(dtp, DateTime()):
df, timeunit = data
Expand Down Expand Up @@ -760,7 +871,6 @@ def render_diff(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
cfg
Config instance
"""

if itmdt.visual_type == "comparison_grid":
visual_elem = render_comparison_grid(itmdt, cfg)
if itmdt.visual_type == "comparison_continuous":
Expand Down
5 changes: 5 additions & 0 deletions dataprep/eda/intermediate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
visual_type = kwargs.pop("visual_type")
super().__init__(**kwargs)
self.visual_type = visual_type
if 'target' in kwargs:
self.target = kwargs.pop('target')

if 'df_list' in kwargs:
self.df_list = kwargs.pop('df_list')
else:
raise ValueError("Unsupported initialization")

Expand Down

0 comments on commit c801377

Please sign in to comment.