Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

accept duplicate column names in table report #1125

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ Bug fixes
fail on polars dataframes when used with the default parameters. This has been
fixed in :pr:`1122` by :user:`Jérôme Dockès <jeromedockes>`.

* The :class:`TableReport` would raise an exception when the input (pandas)
dataframe contained several columns with the same name. This has been fixed in
:pr:`1125` by :user:`Jérôme Dockès <jeromedockes>`.

Release 0.3.1
=============

Expand Down
4 changes: 2 additions & 2 deletions skrub/_column_associations.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def _onehot_encode(df, n_bins):
"""
n_rows, n_cols = sbd.shape(df)
output = np.zeros((n_cols, n_bins, n_rows), dtype=bool)
for col_idx, col_name in enumerate(sbd.column_names(df)):
col = sbd.col(df, col_name)
for col_idx in range(n_cols):
col = sbd.col_by_idx(df, col_idx)
if sbd.is_numeric(col):
col = sbd.to_float32(col)
if _CATEGORICAL_THRESHOLD <= sbd.n_unique(col):
Expand Down
18 changes: 17 additions & 1 deletion skrub/_dataframe/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"is_column_list",
"to_column_list",
"col",
"col_by_idx",
"collect",
#
# Querying and modifying metadata
Expand Down Expand Up @@ -350,7 +351,7 @@
if is_column(obj):
return [obj]
if is_dataframe(obj):
return [col(obj, c) for c in column_names(obj)]
return [col_by_idx(obj, idx) for idx in range(shape(obj)[1])]
if not is_column_list(obj):
raise TypeError("obj should be a DataFrame, a Column or a list of Columns.")
return obj
Expand All @@ -371,6 +372,21 @@
return df[col_name]


@dispatch
def col_by_idx(df, col_idx):
raise NotImplementedError()


@col_by_idx.specialize("pandas", argument_type="DataFrame")
def _col_by_idx_pandas(df, col_idx):
return df.iloc[:, col_idx]


@col_by_idx.specialize("polars", argument_type="DataFrame")
def _col_by_idx_polars(df, col_idx):
return df[df.columns[col_idx]]

Check warning on line 387 in skrub/_dataframe/_common.py

View check run for this annotation

Codecov / codecov/patch

skrub/_dataframe/_common.py#L387

Added line #L387 was not covered by tests


@dispatch
def collect(df):
return df
Expand Down
24 changes: 24 additions & 0 deletions skrub/_dataframe/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,16 @@ def test_to_column_list(df_module, example_data_dict):
ns.to_column_list(None)


def test_to_column_list_duplicate_columns(pd_module):
df = pd_module.make_dataframe({"a": [1, 2], "b": [3, 4]})
df.columns = ["a", "a"]
col_list = ns.to_column_list(df)
assert ns.name(col_list[0]) == "a"
assert ns.to_list(col_list[0]) == [1, 2]
assert ns.name(col_list[1]) == "a"
assert ns.to_list(col_list[1]) == [3, 4]


def test_collect(df_module):
assert ns.collect(df_module.example_dataframe) is df_module.example_dataframe
if df_module.name == "polars":
Expand All @@ -236,6 +246,20 @@ def test_collect(df_module):
)


def test_col(df_module):
assert ns.to_list(ns.col(df_module.example_dataframe, "float-col"))[0] == 4.5


def test_col_by_idx(df_module):
assert ns.name(ns.col_by_idx(df_module.example_dataframe, 2)) == "float-col"


def test_col_by_idx_duplicate_columns(pd_module):
df = pd_module.make_dataframe({"a": [1, 2], "b": [3, 4]})
df.columns = ["a", "a"]
assert ns.to_list(ns.col_by_idx(df, 0)) == [1, 2]


#
# Querying and modifying metadata
# ===============================
Expand Down
12 changes: 8 additions & 4 deletions skrub/_reporting/_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
"n_rows": n_rows,
"n_columns": n_columns,
"columns": [],
"first_row_dict": _utils.first_row_dict(df) if n_rows else {},
"dataframe_is_empty": not n_rows or not n_columns,
"sample_table": _sample_table.make_table(
df,
Expand All @@ -71,17 +70,22 @@
if order_by is not None:
df = sbd.sort(df, by=order_by)
summary["order_by"] = order_by
for position, column_name in enumerate(sbd.column_names(df)):
if order_by is None:
order_by_column = None
else:
order_by_idx = sbd.column_names(df).index(order_by)
order_by_column = sbd.col_by_idx(df, order_by_idx)

Check warning on line 77 in skrub/_reporting/_summarize.py

View check run for this annotation

Codecov / codecov/patch

skrub/_reporting/_summarize.py#L76-L77

Added lines #L76 - L77 were not covered by tests
for position in range(sbd.shape(df)[1]):
print(
f"Processing column {position + 1: >3} / {n_columns}", end="\r", flush=True
)
summary["columns"].append(
_summarize_column(
sbd.col(df, column_name),
sbd.col_by_idx(df, position),
position,
dataframe_summary=summary,
with_plots=with_plots,
order_by_column=None if order_by is None else sbd.col(df, order_by),
order_by_column=order_by_column,
)
)
print(flush=True)
Expand Down
2 changes: 1 addition & 1 deletion skrub/_reporting/_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@
str :
The JSON data.
"""
to_remove = ["dataframe", "sample_table", "first_row_dict"]
to_remove = ["dataframe", "sample_table"]

Check warning on line 179 in skrub/_reporting/_table_report.py

View check run for this annotation

Codecov / codecov/patch

skrub/_reporting/_table_report.py#L179

Added line #L179 was not covered by tests
data = {
k: v for k, v in self._summary_without_plots.items() if k not in to_remove
}
Expand Down
5 changes: 0 additions & 5 deletions skrub/_reporting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ def _to_dict_polars(df):
return df.to_dict(as_series=False)


def first_row_dict(dataframe):
first_row = sbd.slice(dataframe, 0, 1)
return {col_name: col[0] for col_name, col in to_dict(first_row).items()}


def top_k_value_counts(column, k):
counts = sbd.value_counts(column)
n_unique = sbd.shape(counts)[0]
Expand Down
28 changes: 12 additions & 16 deletions skrub/_reporting/tests/test_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np
import pandas as pd
import pytest
import zoneinfo

from skrub import _column_associations
from skrub import _dataframe as sbd
Expand All @@ -24,21 +23,6 @@ def test_summarize(monkeypatch, df_module, air_quality, order_by, with_plots):
assert summary["n_columns"] == 11
assert summary["n_constant_columns"] == 4
assert summary["n_rows"] == 17
assert summary["first_row_dict"] == {
"all_null": None,
"city": "London",
"constant_datetime": datetime.datetime(2024, 7, 5, 12, 17, 29, 427865),
"constant_numeric": 2.7,
"country": "GB",
"date.utc": datetime.datetime(
2019, 6, 13, 0, 0, tzinfo=zoneinfo.ZoneInfo(key="UTC")
),
"loc_with_nulls": None,
"location": "London Westminster",
"parameter": "no2",
"unit": "µg/m³",
"value": 29.0,
}
assert summary["dataframe"] is air_quality
assert summary["dataframe_module"] == df_module.name
assert summary["sample_table"]["start_i"] == (
Expand Down Expand Up @@ -245,3 +229,15 @@ def test_level_names():
assert _sample_table._level_names(idx) == ["the name"]
idx.names = ["a", "b"]
assert _sample_table._level_names(idx) == ["a", "b"]


def test_duplicate_columns(pd_module):
df = pd_module.make_dataframe({"a": [1, 2], "b": [3, 4]})
df.columns = ["a", "a"]
summary = summarize_dataframe(df)
cols = summary["columns"]
assert len(cols) == 2
assert cols[0]["name"] == "a"
assert cols[0]["mean"] == 1.5
assert cols[1]["name"] == "a"
assert cols[1]["mean"] == 3.5
6 changes: 6 additions & 0 deletions skrub/_reporting/tests/test_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,9 @@ def test_nat(df_module):
col = ToDatetime().fit_transform(col)
df = df_module.make_dataframe({"a": col})
TableReport(df).html()


def test_duplicate_columns(pd_module):
df = pd_module.make_dataframe({"a": [1, 2], "b": [3, 4]})
df.columns = ["a", "a"]
TableReport(df).html()