diff --git a/modin/core/dataframe/pandas/partitioning/partition_manager.py b/modin/core/dataframe/pandas/partitioning/partition_manager.py index a642f9453f1..904a2809b4f 100644 --- a/modin/core/dataframe/pandas/partitioning/partition_manager.py +++ b/modin/core/dataframe/pandas/partitioning/partition_manager.py @@ -257,8 +257,16 @@ def groupby_reduce( ) else: mapped_partitions = cls.map_partitions(partitions, map_func) + + # Assuming, that the output will not be larger than the input, + # keep the current number of partitions. + num_splits = min(len(partitions), NPartitions.get()) return cls.map_axis_partitions( - axis, mapped_partitions, reduce_func, enumerate_partitions=True + axis, + mapped_partitions, + reduce_func, + enumerate_partitions=True, + num_splits=num_splits, ) @classmethod