diff --git a/xla/hlo/utils/BUILD b/xla/hlo/utils/BUILD index b4900fc092e4b..c3e47f05dca03 100644 --- a/xla/hlo/utils/BUILD +++ b/xla/hlo/utils/BUILD @@ -119,6 +119,7 @@ cc_library( "//xla/hlo/ir:tile_assignment", "//xla/service:call_graph", "//xla/service:dot_as_convolution_util", + "//xla/service:gather_scatter_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xla/hlo/utils/hlo_sharding_util.cc b/xla/hlo/utils/hlo_sharding_util.cc index def884e17ef0c..1c60203f48d45 100644 --- a/xla/hlo/utils/hlo_sharding_util.cc +++ b/xla/hlo/utils/hlo_sharding_util.cc @@ -54,6 +54,7 @@ limitations under the License. #include "xla/protobuf_util.h" #include "xla/service/call_graph.h" #include "xla/service/dot_as_convolution_util.h" +#include "xla/service/gather_scatter_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -1513,8 +1514,8 @@ absl::InlinedVector GetGatherScatterOperandPassthroughOperandDims( absl::InlinedVector passthrough_dims; int64_t collapsed_or_batching = 0; for (int64_t i = 0; i < operand_shape.rank(); ++i) { - if (absl::c_linear_search(collapsed_or_inserted_dims, i) || - absl::c_linear_search(operand_batching_dims, i)) { + if (IsCollapsedOrBatchingDim(collapsed_or_inserted_dims, + operand_batching_dims, i)) { collapsed_or_batching++; continue; } @@ -1546,8 +1547,8 @@ GetGatherScatterOperandPassthroughOutputOrUpdateDims( absl::InlinedVector passthrough_dims; int64_t collapsed_or_batching = 0; for (int64_t i = 0; i < operand_shape.rank(); ++i) { - if (absl::c_linear_search(collapsed_or_inserted_dims, i) || - absl::c_linear_search(operand_batching_dims, i)) { + if (IsCollapsedOrBatchingDim(collapsed_or_inserted_dims, + operand_batching_dims, i)) { collapsed_or_batching++; continue; } @@ -1581,8 +1582,8 @@ std::optional PassthroughOperandToGatherOutputOrScatterUpdate( DimensionVector passthrough_tile(output_or_update_rank, 1); int64_t collapsed_or_batching = 0; for (int64_t i = 0; i < operand_shape.rank(); ++i) { - if (absl::c_linear_search(collapsed_or_inserted_dims, i) || - absl::c_linear_search(operand_batching_dims, i)) { + if (IsCollapsedOrBatchingDim(collapsed_or_inserted_dims, + operand_batching_dims, i)) { collapsed_or_batching++; continue; } @@ -1636,8 +1637,8 @@ std::optional PassthroughGatherOutputOrScatterUpdateToOperand( // Relevant dims have shardings passed to the operand. DimensionVector relevant_output_or_update_dims; for (int64_t i = 0; i < operand_shape.rank(); ++i) { - if (absl::c_linear_search(collapsed_or_inserted_dims, i) || - absl::c_linear_search(operand_batching_dims, i)) { + if (IsCollapsedOrBatchingDim(collapsed_or_inserted_dims, + operand_batching_dims, i)) { collapsed_or_batching++; continue; } @@ -1770,8 +1771,8 @@ std::vector GetScatterSliceSize(const Shape& operand_shape, std::vector slice_size(operand_shape.rank(), 1); int64_t num_update_window_dims = 0; for (int64_t i = 0; i < operand_shape.rank(); ++i) { - if (absl::c_linear_search(dnums.inserted_window_dims(), i) || - absl::c_linear_search(dnums.input_batching_dims(), i)) { + if (IsCollapsedOrBatchingDim(dnums.inserted_window_dims(), + dnums.input_batching_dims(), i)) { continue; } slice_size[i] = update_shape.dimensions( diff --git a/xla/service/gather_expander.cc b/xla/service/gather_expander.cc index 095cb318b2833..d0190d02d7c22 100644 --- a/xla/service/gather_expander.cc +++ b/xla/service/gather_expander.cc @@ -222,8 +222,8 @@ HloInstruction* CreateGatherLoopAccumulatorInitValue( accumulator_state_shape_dims.reserve(1 + slice_sizes.size()); accumulator_state_shape_dims.push_back(gather_loop_trip_count); for (int64_t i = 0; i < slice_sizes.size(); i++) { - if (!absl::c_linear_search(dim_numbers.collapsed_slice_dims(), i) && - !absl::c_linear_search(dim_numbers.operand_batching_dims(), i)) { + if (!IsCollapsedOrBatchingDim(dim_numbers.collapsed_slice_dims(), + dim_numbers.operand_batching_dims(), i)) { accumulator_state_shape_dims.push_back(slice_sizes[i]); } } diff --git a/xla/service/gather_scatter_utils.cc b/xla/service/gather_scatter_utils.cc index 1b75952475466..98473a0264ecf 100644 --- a/xla/service/gather_scatter_utils.cc +++ b/xla/service/gather_scatter_utils.cc @@ -214,4 +214,10 @@ absl::StatusOr ExpandIndexVectorIntoOperandSpace( return MakeConcatHlo(expanded_index_components, /*dimension=*/0); } +bool IsCollapsedOrBatchingDim(absl::Span collapsed_dims, + absl::Span batching_dims, + int64_t dim) { + return absl::c_linear_search(collapsed_dims, dim) || + absl::c_linear_search(batching_dims, dim); +} } // namespace xla diff --git a/xla/service/gather_scatter_utils.h b/xla/service/gather_scatter_utils.h index cf1368373583b..e73127003261d 100644 --- a/xla/service/gather_scatter_utils.h +++ b/xla/service/gather_scatter_utils.h @@ -62,6 +62,10 @@ absl::StatusOr ExpandIndexVectorIntoOperandSpace( absl::Span operand_batching_dims, HloInstruction* index_vector, HloInstruction* induction_var); +// Returns true if the given dimension is a collapsed or batching dimension. +bool IsCollapsedOrBatchingDim(absl::Span collapsed_dims, + absl::Span batching_dims, + int64_t dim); } // namespace xla #endif // XLA_SERVICE_GATHER_SCATTER_UTILS_H_