Skip to content

Commit

Permalink
df.apply should allow pd.NA from Callables (#961)
Browse files Browse the repository at this point in the history
  • Loading branch information
JanEricNitschke authored Jul 15, 2024
1 parent e78aaca commit 71b2555
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1224,7 +1224,7 @@ class DataFrame(NDFrame, OpsMixin):
@overload
def apply(
self,
f: Callable[..., S1],
f: Callable[..., S1 | NAType],
axis: AxisIndex = ...,
raw: _bool = ...,
result_type: None = ...,
Expand All @@ -1248,7 +1248,7 @@ class DataFrame(NDFrame, OpsMixin):
@overload
def apply(
self,
f: Callable[..., S1],
f: Callable[..., S1 | NAType],
axis: Axis = ...,
raw: _bool = ...,
args: Any = ...,
Expand Down Expand Up @@ -1309,7 +1309,7 @@ class DataFrame(NDFrame, OpsMixin):
@overload
def apply(
self,
f: Callable[..., S1],
f: Callable[..., S1 | NAType],
raw: _bool = ...,
result_type: None = ...,
args: Any = ...,
Expand Down
9 changes: 9 additions & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
import xarray as xr

from pandas._libs.missing import NAType
from pandas._typing import Scalar

from tests import (
Expand Down Expand Up @@ -578,6 +579,9 @@ def test_types_apply() -> None:
def returns_scalar(x: pd.Series) -> int:
return 2

def returns_scalar_na(x: pd.Series) -> int | NAType:
return 2 if (x < 5).all() else pd.NA

def returns_series(x: pd.Series) -> pd.Series:
return x**2

Expand All @@ -604,6 +608,11 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
check(
assert_type(df.apply(returns_scalar), "pd.Series[int]"), pd.Series, np.integer
)
check(
assert_type(df.apply(returns_scalar_na), "pd.Series[int]"),
pd.Series,
int,
)
check(assert_type(df.apply(returns_series), pd.DataFrame), pd.DataFrame)
check(assert_type(df.apply(returns_listlike_of_3), pd.DataFrame), pd.DataFrame)
check(assert_type(df.apply(returns_dict), pd.Series), pd.Series)
Expand Down

0 comments on commit 71b2555

Please sign in to comment.