diff --git a/modin/core/dataframe/pandas/dataframe/dataframe.py b/modin/core/dataframe/pandas/dataframe/dataframe.py index 21e2356f9b3..5456f28f127 100644 --- a/modin/core/dataframe/pandas/dataframe/dataframe.py +++ b/modin/core/dataframe/pandas/dataframe/dataframe.py @@ -3255,7 +3255,6 @@ def broadcast_apply( axis, other, join_type, - sort=not self.get_axis(axis).equals(other.get_axis(axis)), ) # unwrap list returned by `copartition`. right_parts = right_parts[0] @@ -3681,7 +3680,7 @@ def _check_if_axes_identical(self, other: PandasDataframe, axis: int = 0) -> boo ) and self._get_axis_lengths(axis) == other._get_axis_lengths(axis) def _copartition( - self, axis, other, how, sort, force_repartition=False, fill_value=None + self, axis, other, how, sort=None, force_repartition=False, fill_value=None ): """ Copartition two Modin DataFrames. @@ -3696,8 +3695,9 @@ def _copartition( Other Modin DataFrame(s) to copartition against. how : str How to manage joining the index object ("left", "right", etc.). - sort : bool + sort : bool, default: None Whether sort the joined index or not. + If ``None``, sort is defined in depend on labels equality along the axis. force_repartition : bool, default: False Whether force the repartitioning or not. By default, this method will skip repartitioning if it is possible. This is because @@ -3730,6 +3730,9 @@ def _copartition( self._get_axis_lengths_cache(axis), ) + if sort is None: + sort = not all(self.get_axis(axis).equals(o.get_axis(axis)) for o in other) + self_index = self.get_axis(axis) others_index = [o.get_axis(axis) for o in other] joined_index, make_reindexer = self._join_index_objects( @@ -3860,13 +3863,7 @@ def n_ary_op( 0, right_frames, join_type, - sort=( - not all( - self.get_axis(0).equals(right.get_axis(0)) for right in right_frames - ) - if sort is None - else sort - ), + sort=sort, ) if copartition_along_columns: new_left_frame = self.__constructor__( @@ -3898,14 +3895,7 @@ def n_ary_op( 1, new_right_frames, join_type, - sort=( - not all( - self.get_axis(1).equals(right.get_axis(1)) - for right in new_right_frames - ) - if sort is None - else sort - ), + sort=sort, ) else: joined_columns = self.copy_columns_cache(copy_lengths=True) @@ -3997,7 +3987,7 @@ def _compute_new_widths(): joined_index, partition_sizes_along_axis, ) = self._copartition( - axis.value ^ 1, others, how, sort, force_repartition=False + axis.value ^ 1, others, how, sort=sort, force_repartition=False ) if axis == Axis.COL_WISE: new_lengths = partition_sizes_along_axis