diff --git a/CHANGES.rst b/CHANGES.rst index b9a3654a6..35f3ce9e1 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -72,6 +72,10 @@ Bug fixes fail on polars dataframes when used with the default parameters. This has been fixed in :pr:`1122` by :user:`Jérôme Dockès `. +* The :class:`TableReport` would raise an exception when the input (pandas) + dataframe contained several columns with the same name. This has been fixed in + :pr:`1125` by :user:`Jérôme Dockès `. + Release 0.3.1 ============= diff --git a/skrub/_column_associations.py b/skrub/_column_associations.py index 219f224f5..44d87a865 100644 --- a/skrub/_column_associations.py +++ b/skrub/_column_associations.py @@ -104,8 +104,8 @@ def _onehot_encode(df, n_bins): """ n_rows, n_cols = sbd.shape(df) output = np.zeros((n_cols, n_bins, n_rows), dtype=bool) - for col_idx, col_name in enumerate(sbd.column_names(df)): - col = sbd.col(df, col_name) + for col_idx in range(n_cols): + col = sbd.col_by_idx(df, col_idx) if sbd.is_numeric(col): col = sbd.to_float32(col) if _CATEGORICAL_THRESHOLD <= sbd.n_unique(col): diff --git a/skrub/_dataframe/_common.py b/skrub/_dataframe/_common.py index 66b8879a9..05f2d856f 100644 --- a/skrub/_dataframe/_common.py +++ b/skrub/_dataframe/_common.py @@ -40,6 +40,7 @@ "is_column_list", "to_column_list", "col", + "col_by_idx", "collect", # # Querying and modifying metadata @@ -350,7 +351,7 @@ def to_column_list(obj): if is_column(obj): return [obj] if is_dataframe(obj): - return [col(obj, c) for c in column_names(obj)] + return [col_by_idx(obj, idx) for idx in range(shape(obj)[1])] if not is_column_list(obj): raise TypeError("obj should be a DataFrame, a Column or a list of Columns.") return obj @@ -371,6 +372,21 @@ def _col_polars(df, col_name): return df[col_name] +@dispatch +def col_by_idx(df, col_idx): + raise NotImplementedError() + + +@col_by_idx.specialize("pandas", argument_type="DataFrame") +def _col_by_idx_pandas(df, col_idx): + return df.iloc[:, col_idx] + + +@col_by_idx.specialize("polars", argument_type="DataFrame") +def _col_by_idx_polars(df, col_idx): + return df[df.columns[col_idx]] + + @dispatch def collect(df): return df diff --git a/skrub/_dataframe/tests/test_common.py b/skrub/_dataframe/tests/test_common.py index 491ddb029..35accaeab 100644 --- a/skrub/_dataframe/tests/test_common.py +++ b/skrub/_dataframe/tests/test_common.py @@ -228,6 +228,16 @@ def test_to_column_list(df_module, example_data_dict): ns.to_column_list(None) +def test_to_column_list_duplicate_columns(pd_module): + df = pd_module.make_dataframe({"a": [1, 2], "b": [3, 4]}) + df.columns = ["a", "a"] + col_list = ns.to_column_list(df) + assert ns.name(col_list[0]) == "a" + assert ns.to_list(col_list[0]) == [1, 2] + assert ns.name(col_list[1]) == "a" + assert ns.to_list(col_list[1]) == [3, 4] + + def test_collect(df_module): assert ns.collect(df_module.example_dataframe) is df_module.example_dataframe if df_module.name == "polars": @@ -236,6 +246,20 @@ def test_collect(df_module): ) +def test_col(df_module): + assert ns.to_list(ns.col(df_module.example_dataframe, "float-col"))[0] == 4.5 + + +def test_col_by_idx(df_module): + assert ns.name(ns.col_by_idx(df_module.example_dataframe, 2)) == "float-col" + + +def test_col_by_idx_duplicate_columns(pd_module): + df = pd_module.make_dataframe({"a": [1, 2], "b": [3, 4]}) + df.columns = ["a", "a"] + assert ns.to_list(ns.col_by_idx(df, 0)) == [1, 2] + + # # Querying and modifying metadata # =============================== diff --git a/skrub/_reporting/_summarize.py b/skrub/_reporting/_summarize.py index 7774cb591..b4495586d 100644 --- a/skrub/_reporting/_summarize.py +++ b/skrub/_reporting/_summarize.py @@ -58,7 +58,6 @@ def summarize_dataframe( "n_rows": n_rows, "n_columns": n_columns, "columns": [], - "first_row_dict": _utils.first_row_dict(df) if n_rows else {}, "dataframe_is_empty": not n_rows or not n_columns, "sample_table": _sample_table.make_table( df, @@ -71,17 +70,22 @@ def summarize_dataframe( if order_by is not None: df = sbd.sort(df, by=order_by) summary["order_by"] = order_by - for position, column_name in enumerate(sbd.column_names(df)): + if order_by is None: + order_by_column = None + else: + order_by_idx = sbd.column_names(df).index(order_by) + order_by_column = sbd.col_by_idx(df, order_by_idx) + for position in range(sbd.shape(df)[1]): print( f"Processing column {position + 1: >3} / {n_columns}", end="\r", flush=True ) summary["columns"].append( _summarize_column( - sbd.col(df, column_name), + sbd.col_by_idx(df, position), position, dataframe_summary=summary, with_plots=with_plots, - order_by_column=None if order_by is None else sbd.col(df, order_by), + order_by_column=order_by_column, ) ) print(flush=True) diff --git a/skrub/_reporting/_table_report.py b/skrub/_reporting/_table_report.py index 72d61a0d0..6cc5f7617 100644 --- a/skrub/_reporting/_table_report.py +++ b/skrub/_reporting/_table_report.py @@ -176,7 +176,7 @@ def json(self): str : The JSON data. """ - to_remove = ["dataframe", "sample_table", "first_row_dict"] + to_remove = ["dataframe", "sample_table"] data = { k: v for k, v in self._summary_without_plots.items() if k not in to_remove } diff --git a/skrub/_reporting/_utils.py b/skrub/_reporting/_utils.py index 1ef8e38ec..d1a249f72 100644 --- a/skrub/_reporting/_utils.py +++ b/skrub/_reporting/_utils.py @@ -29,11 +29,6 @@ def _to_dict_polars(df): return df.to_dict(as_series=False) -def first_row_dict(dataframe): - first_row = sbd.slice(dataframe, 0, 1) - return {col_name: col[0] for col_name, col in to_dict(first_row).items()} - - def top_k_value_counts(column, k): counts = sbd.value_counts(column) n_unique = sbd.shape(counts)[0] diff --git a/skrub/_reporting/tests/test_summarize.py b/skrub/_reporting/tests/test_summarize.py index 86959d4f3..dc7e7cf20 100644 --- a/skrub/_reporting/tests/test_summarize.py +++ b/skrub/_reporting/tests/test_summarize.py @@ -4,7 +4,6 @@ import numpy as np import pandas as pd import pytest -import zoneinfo from skrub import _column_associations from skrub import _dataframe as sbd @@ -24,21 +23,6 @@ def test_summarize(monkeypatch, df_module, air_quality, order_by, with_plots): assert summary["n_columns"] == 11 assert summary["n_constant_columns"] == 4 assert summary["n_rows"] == 17 - assert summary["first_row_dict"] == { - "all_null": None, - "city": "London", - "constant_datetime": datetime.datetime(2024, 7, 5, 12, 17, 29, 427865), - "constant_numeric": 2.7, - "country": "GB", - "date.utc": datetime.datetime( - 2019, 6, 13, 0, 0, tzinfo=zoneinfo.ZoneInfo(key="UTC") - ), - "loc_with_nulls": None, - "location": "London Westminster", - "parameter": "no2", - "unit": "µg/m³", - "value": 29.0, - } assert summary["dataframe"] is air_quality assert summary["dataframe_module"] == df_module.name assert summary["sample_table"]["start_i"] == ( @@ -245,3 +229,15 @@ def test_level_names(): assert _sample_table._level_names(idx) == ["the name"] idx.names = ["a", "b"] assert _sample_table._level_names(idx) == ["a", "b"] + + +def test_duplicate_columns(pd_module): + df = pd_module.make_dataframe({"a": [1, 2], "b": [3, 4]}) + df.columns = ["a", "a"] + summary = summarize_dataframe(df) + cols = summary["columns"] + assert len(cols) == 2 + assert cols[0]["name"] == "a" + assert cols[0]["mean"] == 1.5 + assert cols[1]["name"] == "a" + assert cols[1]["mean"] == 3.5 diff --git a/skrub/_reporting/tests/test_table_report.py b/skrub/_reporting/tests/test_table_report.py index 5f4adcf83..2fae33397 100644 --- a/skrub/_reporting/tests/test_table_report.py +++ b/skrub/_reporting/tests/test_table_report.py @@ -93,3 +93,9 @@ def test_nat(df_module): col = ToDatetime().fit_transform(col) df = df_module.make_dataframe({"a": col}) TableReport(df).html() + + +def test_duplicate_columns(pd_module): + df = pd_module.make_dataframe({"a": [1, 2], "b": [3, 4]}) + df.columns = ["a", "a"] + TableReport(df).html()