Skip to content

Commit

Permalink
[xla:NFC] Add a utility function IsCollapsedOrBatchingDim.
Browse files Browse the repository at this point in the history
Use the function in hlo_sharding_util and other places.

PiperOrigin-RevId: 695336795
  • Loading branch information
bixia1 authored and Google-ML-Automation committed Nov 11, 2024
1 parent 9cf35f4 commit e4755ec
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 12 deletions.
1 change: 1 addition & 0 deletions xla/hlo/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ cc_library(
"//xla/hlo/ir:tile_assignment",
"//xla/service:call_graph",
"//xla/service:dot_as_convolution_util",
"//xla/service:gather_scatter_utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down
21 changes: 11 additions & 10 deletions xla/hlo/utils/hlo_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ limitations under the License.
#include "xla/protobuf_util.h"
#include "xla/service/call_graph.h"
#include "xla/service/dot_as_convolution_util.h"
#include "xla/service/gather_scatter_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/util.h"
Expand Down Expand Up @@ -1513,8 +1514,8 @@ absl::InlinedVector<int64_t, 1> GetGatherScatterOperandPassthroughOperandDims(
absl::InlinedVector<int64_t, 1> passthrough_dims;
int64_t collapsed_or_batching = 0;
for (int64_t i = 0; i < operand_shape.rank(); ++i) {
if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
absl::c_linear_search(operand_batching_dims, i)) {
if (IsCollapsedOrBatchingDim(collapsed_or_inserted_dims,
operand_batching_dims, i)) {
collapsed_or_batching++;
continue;
}
Expand Down Expand Up @@ -1546,8 +1547,8 @@ GetGatherScatterOperandPassthroughOutputOrUpdateDims(
absl::InlinedVector<int64_t, 1> passthrough_dims;
int64_t collapsed_or_batching = 0;
for (int64_t i = 0; i < operand_shape.rank(); ++i) {
if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
absl::c_linear_search(operand_batching_dims, i)) {
if (IsCollapsedOrBatchingDim(collapsed_or_inserted_dims,
operand_batching_dims, i)) {
collapsed_or_batching++;
continue;
}
Expand Down Expand Up @@ -1581,8 +1582,8 @@ std::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
DimensionVector passthrough_tile(output_or_update_rank, 1);
int64_t collapsed_or_batching = 0;
for (int64_t i = 0; i < operand_shape.rank(); ++i) {
if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
absl::c_linear_search(operand_batching_dims, i)) {
if (IsCollapsedOrBatchingDim(collapsed_or_inserted_dims,
operand_batching_dims, i)) {
collapsed_or_batching++;
continue;
}
Expand Down Expand Up @@ -1636,8 +1637,8 @@ std::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
// Relevant dims have shardings passed to the operand.
DimensionVector relevant_output_or_update_dims;
for (int64_t i = 0; i < operand_shape.rank(); ++i) {
if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
absl::c_linear_search(operand_batching_dims, i)) {
if (IsCollapsedOrBatchingDim(collapsed_or_inserted_dims,
operand_batching_dims, i)) {
collapsed_or_batching++;
continue;
}
Expand Down Expand Up @@ -1770,8 +1771,8 @@ std::vector<int64_t> GetScatterSliceSize(const Shape& operand_shape,
std::vector<int64_t> slice_size(operand_shape.rank(), 1);
int64_t num_update_window_dims = 0;
for (int64_t i = 0; i < operand_shape.rank(); ++i) {
if (absl::c_linear_search(dnums.inserted_window_dims(), i) ||
absl::c_linear_search(dnums.input_batching_dims(), i)) {
if (IsCollapsedOrBatchingDim(dnums.inserted_window_dims(),
dnums.input_batching_dims(), i)) {
continue;
}
slice_size[i] = update_shape.dimensions(
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gather_expander.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ HloInstruction* CreateGatherLoopAccumulatorInitValue(
accumulator_state_shape_dims.reserve(1 + slice_sizes.size());
accumulator_state_shape_dims.push_back(gather_loop_trip_count);
for (int64_t i = 0; i < slice_sizes.size(); i++) {
if (!absl::c_linear_search(dim_numbers.collapsed_slice_dims(), i) &&
!absl::c_linear_search(dim_numbers.operand_batching_dims(), i)) {
if (!IsCollapsedOrBatchingDim(dim_numbers.collapsed_slice_dims(),
dim_numbers.operand_batching_dims(), i)) {
accumulator_state_shape_dims.push_back(slice_sizes[i]);
}
}
Expand Down
6 changes: 6 additions & 0 deletions xla/service/gather_scatter_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,10 @@ absl::StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
}

bool IsCollapsedOrBatchingDim(absl::Span<const int64_t> collapsed_dims,
absl::Span<const int64_t> batching_dims,
int64_t dim) {
return absl::c_linear_search(collapsed_dims, dim) ||
absl::c_linear_search(batching_dims, dim);
}
} // namespace xla
4 changes: 4 additions & 0 deletions xla/service/gather_scatter_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ absl::StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
absl::Span<const int64_t> operand_batching_dims,
HloInstruction* index_vector, HloInstruction* induction_var);

// Returns true if the given dimension is a collapsed or batching dimension.
bool IsCollapsedOrBatchingDim(absl::Span<const int64_t> collapsed_dims,
absl::Span<const int64_t> batching_dims,
int64_t dim);
} // namespace xla

#endif // XLA_SERVICE_GATHER_SCATTER_UTILS_H_

0 comments on commit e4755ec

Please sign in to comment.