Skip to content

Commit

Permalink
accept duplicate column names in table report (#1125)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromedockes authored Oct 24, 2024
1 parent 2e87e10 commit d418554
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 29 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ Bug fixes
fail on polars dataframes when used with the default parameters. This has been
fixed in :pr:`1122` by :user:`Jérôme Dockès <jeromedockes>`.

* The :class:`TableReport` would raise an exception when the input (pandas)
dataframe contained several columns with the same name. This has been fixed in
:pr:`1125` by :user:`Jérôme Dockès <jeromedockes>`.

Release 0.3.1
=============

Expand Down
4 changes: 2 additions & 2 deletions skrub/_column_associations.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def _onehot_encode(df, n_bins):
"""
n_rows, n_cols = sbd.shape(df)
output = np.zeros((n_cols, n_bins, n_rows), dtype=bool)
for col_idx, col_name in enumerate(sbd.column_names(df)):
col = sbd.col(df, col_name)
for col_idx in range(n_cols):
col = sbd.col_by_idx(df, col_idx)
if sbd.is_numeric(col):
col = sbd.to_float32(col)
if _CATEGORICAL_THRESHOLD <= sbd.n_unique(col):
Expand Down
18 changes: 17 additions & 1 deletion skrub/_dataframe/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"is_column_list",
"to_column_list",
"col",
"col_by_idx",
"collect",
#
# Querying and modifying metadata
Expand Down Expand Up @@ -350,7 +351,7 @@ def to_column_list(obj):
if is_column(obj):
return [obj]
if is_dataframe(obj):
return [col(obj, c) for c in column_names(obj)]
return [col_by_idx(obj, idx) for idx in range(shape(obj)[1])]
if not is_column_list(obj):
raise TypeError("obj should be a DataFrame, a Column or a list of Columns.")
return obj
Expand All @@ -371,6 +372,21 @@ def _col_polars(df, col_name):
return df[col_name]


@dispatch
def col_by_idx(df, col_idx):
    """Return the column of ``df`` at position ``col_idx``.

    Unlike ``col`` (lookup by name), positional lookup works even when the
    dataframe contains several columns with the same name.  This generic
    fallback fires only when no backend-specific implementation has been
    registered for ``type(df)`` via ``@col_by_idx.specialize``.
    """
    raise NotImplementedError()


@col_by_idx.specialize("pandas", argument_type="DataFrame")
def _col_by_idx_pandas(df, col_idx):
    # Positional selection with iloc is safe even when several columns
    # share the same name (label-based access would return a DataFrame).
    series = df.iloc[:, col_idx]
    return series


@col_by_idx.specialize("polars", argument_type="DataFrame")
def _col_by_idx_polars(df, col_idx):
    # Resolve the positional index to a name first; polars forbids
    # duplicate column names so the name lookup is unambiguous.
    column_name = df.columns[col_idx]
    return df[column_name]


@dispatch
def collect(df):
return df
Expand Down
24 changes: 24 additions & 0 deletions skrub/_dataframe/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,16 @@ def test_to_column_list(df_module, example_data_dict):
ns.to_column_list(None)


def test_to_column_list_duplicate_columns(pd_module):
    """to_column_list must handle dataframes with duplicated column names."""
    df = pd_module.make_dataframe({"a": [1, 2], "b": [3, 4]})
    df.columns = ["a", "a"]
    first, second = ns.to_column_list(df)
    assert ns.name(first) == "a"
    assert ns.to_list(first) == [1, 2]
    assert ns.name(second) == "a"
    assert ns.to_list(second) == [3, 4]


def test_collect(df_module):
assert ns.collect(df_module.example_dataframe) is df_module.example_dataframe
if df_module.name == "polars":
Expand All @@ -236,6 +246,20 @@ def test_collect(df_module):
)


def test_col(df_module):
    """col retrieves a column of the dataframe by name."""
    float_col = ns.col(df_module.example_dataframe, "float-col")
    assert ns.to_list(float_col)[0] == 4.5


def test_col_by_idx(df_module):
    """col_by_idx retrieves a column of the dataframe by position."""
    column = ns.col_by_idx(df_module.example_dataframe, 2)
    assert ns.name(column) == "float-col"


def test_col_by_idx_duplicate_columns(pd_module):
    """Positional lookup is unambiguous even with duplicate column names."""
    df = pd_module.make_dataframe({"a": [1, 2], "b": [3, 4]})
    df.columns = ["a", "a"]
    first_column = ns.col_by_idx(df, 0)
    assert ns.to_list(first_column) == [1, 2]


#
# Querying and modifying metadata
# ===============================
Expand Down
12 changes: 8 additions & 4 deletions skrub/_reporting/_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def summarize_dataframe(
"n_rows": n_rows,
"n_columns": n_columns,
"columns": [],
"first_row_dict": _utils.first_row_dict(df) if n_rows else {},
"dataframe_is_empty": not n_rows or not n_columns,
"sample_table": _sample_table.make_table(
df,
Expand All @@ -71,17 +70,22 @@ def summarize_dataframe(
if order_by is not None:
df = sbd.sort(df, by=order_by)
summary["order_by"] = order_by
for position, column_name in enumerate(sbd.column_names(df)):
if order_by is None:
order_by_column = None
else:
order_by_idx = sbd.column_names(df).index(order_by)
order_by_column = sbd.col_by_idx(df, order_by_idx)
for position in range(sbd.shape(df)[1]):
print(
f"Processing column {position + 1: >3} / {n_columns}", end="\r", flush=True
)
summary["columns"].append(
_summarize_column(
sbd.col(df, column_name),
sbd.col_by_idx(df, position),
position,
dataframe_summary=summary,
with_plots=with_plots,
order_by_column=None if order_by is None else sbd.col(df, order_by),
order_by_column=order_by_column,
)
)
print(flush=True)
Expand Down
2 changes: 1 addition & 1 deletion skrub/_reporting/_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def json(self):
str :
The JSON data.
"""
to_remove = ["dataframe", "sample_table", "first_row_dict"]
to_remove = ["dataframe", "sample_table"]
data = {
k: v for k, v in self._summary_without_plots.items() if k not in to_remove
}
Expand Down
5 changes: 0 additions & 5 deletions skrub/_reporting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ def _to_dict_polars(df):
return df.to_dict(as_series=False)


def first_row_dict(dataframe):
    """Return a mapping from column name to the value in the first row."""
    first_row = sbd.slice(dataframe, 0, 1)
    row_as_columns = to_dict(first_row)
    return {name: values[0] for name, values in row_as_columns.items()}


def top_k_value_counts(column, k):
counts = sbd.value_counts(column)
n_unique = sbd.shape(counts)[0]
Expand Down
28 changes: 12 additions & 16 deletions skrub/_reporting/tests/test_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np
import pandas as pd
import pytest
import zoneinfo

from skrub import _column_associations
from skrub import _dataframe as sbd
Expand All @@ -24,21 +23,6 @@ def test_summarize(monkeypatch, df_module, air_quality, order_by, with_plots):
assert summary["n_columns"] == 11
assert summary["n_constant_columns"] == 4
assert summary["n_rows"] == 17
assert summary["first_row_dict"] == {
"all_null": None,
"city": "London",
"constant_datetime": datetime.datetime(2024, 7, 5, 12, 17, 29, 427865),
"constant_numeric": 2.7,
"country": "GB",
"date.utc": datetime.datetime(
2019, 6, 13, 0, 0, tzinfo=zoneinfo.ZoneInfo(key="UTC")
),
"loc_with_nulls": None,
"location": "London Westminster",
"parameter": "no2",
"unit": "µg/m³",
"value": 29.0,
}
assert summary["dataframe"] is air_quality
assert summary["dataframe_module"] == df_module.name
assert summary["sample_table"]["start_i"] == (
Expand Down Expand Up @@ -245,3 +229,15 @@ def test_level_names():
assert _sample_table._level_names(idx) == ["the name"]
idx.names = ["a", "b"]
assert _sample_table._level_names(idx) == ["a", "b"]


def test_duplicate_columns(pd_module):
    """summarize_dataframe must not fail when column names are duplicated."""
    df = pd_module.make_dataframe({"a": [1, 2], "b": [3, 4]})
    df.columns = ["a", "a"]
    summary = summarize_dataframe(df)
    columns = summary["columns"]
    assert len(columns) == 2
    # Both columns keep the shared name but are summarized independently.
    for column_summary, expected_mean in zip(columns, [1.5, 3.5]):
        assert column_summary["name"] == "a"
        assert column_summary["mean"] == expected_mean
6 changes: 6 additions & 0 deletions skrub/_reporting/tests/test_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,9 @@ def test_nat(df_module):
col = ToDatetime().fit_transform(col)
df = df_module.make_dataframe({"a": col})
TableReport(df).html()


def test_duplicate_columns(pd_module):
    """Rendering a report must not fail when column names are duplicated."""
    df = pd_module.make_dataframe({"a": [1, 2], "b": [3, 4]})
    df.columns = ["a", "a"]
    report = TableReport(df)
    report.html()

0 comments on commit d418554

Please sign in to comment.