From 5aa374cf68f4176f854be3d0f6266a22016d719f Mon Sep 17 00:00:00 2001 From: Kirill Suvorov Date: Wed, 10 Apr 2024 10:10:40 +0000 Subject: [PATCH] Fix tests --- modin/core/dataframe/pandas/dataframe/dataframe.py | 2 +- .../pandas/partitioning/partition_manager.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/modin/core/dataframe/pandas/dataframe/dataframe.py b/modin/core/dataframe/pandas/dataframe/dataframe.py index 3464b7f9f0b..cde44b178a9 100644 --- a/modin/core/dataframe/pandas/dataframe/dataframe.py +++ b/modin/core/dataframe/pandas/dataframe/dataframe.py @@ -2232,7 +2232,7 @@ def map( func, keep_partitioning=True, map_func_args=func_args, - **func_kwargs if func_kwargs is not None else {}, + map_func_kwargs=func_kwargs, ) else: new_partitions = ( diff --git a/modin/core/dataframe/pandas/partitioning/partition_manager.py b/modin/core/dataframe/pandas/partitioning/partition_manager.py index 26154d42056..42b0c6f3c4d 100644 --- a/modin/core/dataframe/pandas/partitioning/partition_manager.py +++ b/modin/core/dataframe/pandas/partitioning/partition_manager.py @@ -494,6 +494,7 @@ def broadcast_axis_partitions( enumerate_partitions=False, lengths=None, apply_func_args=None, + apply_func_kwargs=None, **kwargs, ): """ @@ -529,6 +530,8 @@ def broadcast_axis_partitions( 2. When passing lengths you must explicitly specify `keep_partitioning=False`. apply_func_args : list-like, optional Positional arguments to pass to the `func`. + apply_func_kwargs : dict, optional + Keyword arguments to pass to the `func`. **kwargs : dict Additional options that could be used by different engines. @@ -580,6 +583,7 @@ def broadcast_axis_partitions( left_partitions[i].apply( preprocessed_map_func, *(apply_func_args if apply_func_args else []), + **apply_func_kwargs if apply_func_kwargs is not None else {}, **kw, **({"partition_idx": idx} if enumerate_partitions else {}), **kwargs, @@ -692,6 +696,7 @@ def map_axis_partitions( lengths=None, enumerate_partitions=False, map_func_args=None, + map_func_kwargs=None, **kwargs, ): """ @@ -723,6 +728,8 @@ def map_axis_partitions( Note that `map_func` must be able to accept `partition_idx` kwarg. map_func_args : list-like, optional Positional arguments to pass to the `map_func`. + map_func_kwargs : dict, optional + Keyword arguments for the 'map_func'. **kwargs : dict Additional options that could be used by different engines. @@ -746,6 +753,7 @@ def map_axis_partitions( lengths=lengths, enumerate_partitions=enumerate_partitions, apply_func_args=map_func_args, + apply_func_kwargs=map_func_kwargs, **kwargs, ) @@ -769,9 +777,9 @@ def map_partitions_splitting_by_column( The number of splits by column. map_func : callable Function to apply. - func_args : iterable, optional + map_func_args : iterable, optional Positional arguments for the 'map_func'. - func_kwargs : dict, optional + map_func_kwargs : dict, optional Keyword arguments for the 'map_func'. Returns