Skip to content

Commit

Permalink
BUG: Copy attrs on pd.merge()
Browse files Browse the repository at this point in the history
This uses the same logic as `pd.concat()`: Copy `attrs` only if all
input `attrs` are identical.

I've refactored the handling in __finalize__ from special-casing based on th the method name (previously only "concat") to handling "other" parameters
that have an `input_objs` attribute. This is a more scalable architecture compared to hard-coding method names in __finalize__.

Tests added for `concat()` and `merge()`.

Closes pandas-dev#60351.
  • Loading branch information
timhoffm committed Nov 19, 2024
1 parent 6a7685f commit 8bd828c
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 7 deletions.
4 changes: 2 additions & 2 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6053,8 +6053,8 @@ def __finalize__(self, other, method: str | None = None, **kwargs) -> Self:
assert isinstance(name, str)
object.__setattr__(self, name, getattr(other, name, None))

if method == "concat":
objs = other.objs
elif hasattr(other, "input_objs"):
objs = other.input_objs
# propagate attrs only if all concat arguments have the same attrs
if all(bool(obj.attrs) for obj in objs):
# all concatenate arguments have non-empty attrs
Expand Down
10 changes: 7 additions & 3 deletions pandas/core/reshape/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def _get_result(
result = sample._constructor_from_mgr(mgr, axes=mgr.axes)
result._name = name
return result.__finalize__(
types.SimpleNamespace(objs=objs), method="concat"
types.SimpleNamespace(input_objs=objs), method="concat"
)

# combine as columns in a frame
Expand All @@ -566,7 +566,9 @@ def _get_result(
)
df = cons(data, index=index, copy=False)
df.columns = columns
return df.__finalize__(types.SimpleNamespace(objs=objs), method="concat")
return df.__finalize__(
types.SimpleNamespace(input_objs=objs), method="concat"
)

# combine block managers
else:
Expand Down Expand Up @@ -605,7 +607,9 @@ def _get_result(
)

out = sample._constructor_from_mgr(new_data, axes=new_data.axes)
return out.__finalize__(types.SimpleNamespace(objs=objs), method="concat")
return out.__finalize__(
types.SimpleNamespace(input_objs=objs), method="concat"
)


def new_axes(
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/reshape/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
import datetime
from functools import partial
import types
from typing import (
TYPE_CHECKING,
Literal,
Expand Down Expand Up @@ -1115,7 +1116,9 @@ def get_result(self) -> DataFrame:

self._maybe_restore_index_levels(result)

return result.__finalize__(self, method="merge")
return result.__finalize__(
types.SimpleNamespace(input_objs=[self.left, self.right]), method="merge"
)

@final
@cache_readonly
Expand Down
26 changes: 25 additions & 1 deletion pandas/tests/frame/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def test_attrs(self):
result = df.rename(columns=str)
assert result.attrs == {"version": 1}

def test_attrs_deepcopy(self):
def test_attrs_is_deepcopy(self):
df = DataFrame({"A": [2, 3]})
assert df.attrs == {}
df.attrs["tags"] = {"spam", "ham"}
Expand All @@ -324,6 +324,30 @@ def test_attrs_deepcopy(self):
assert result.attrs == df.attrs
assert result.attrs["tags"] is not df.attrs["tags"]

def test_attrs_concat(self):
# concat propagates attrs if all input attrs are equal
df1 = DataFrame({"A": [2, 3]})
df1.attrs = {'a': 1, 'b': 2}
df2 = DataFrame({"A": [4, 5]})
df2.attrs = df1.attrs.copy()
df3 = DataFrame({"A": [6, 7]})
df3.attrs = df1.attrs.copy()
assert pd.concat([df1, df2, df3]).attrs == df1.attrs
# concat does not propagate attrs if input attrs are different
df2.attrs = {'c': 3}
assert pd.concat([df1, df2, df3]).attrs == {}

def test_attrs_merge(self):
# merge propagates attrs if all input attrs are equal
df1 = DataFrame({"key": ['a', 'b'], 'val1': [1, 2]})
df1.attrs = {'a': 1, 'b': 2}
df2 = DataFrame({"key": ['a', 'b'], 'val2': [3, 4]})
df2.attrs = df1.attrs.copy()
assert pd.merge(df1, df2).attrs == df1.attrs
# merge does not propagate attrs if input attrs are different
df2.attrs = {'c': 3}
assert pd.merge(df1, df2).attrs == {}

@pytest.mark.parametrize("allows_duplicate_labels", [True, False, None])
def test_set_flags(
self,
Expand Down

0 comments on commit 8bd828c

Please sign in to comment.