From e866528e77a5c1691e5ff22822e63ebb927807fc Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Wed, 27 Mar 2024 12:28:20 -0700 Subject: [PATCH] [XLA:GPU] Add a Gemv rewriter pass to convert gemv to gemm with a trivial dimension We need this pass since GemmFusion only accepts gemms. We should run this pass before GemmFusion. After GemmFusion, we should use AlgebraicSimplifier to remove trivial dimension from gemms that are not fused by GemmFusion. PiperOrigin-RevId: 619615263 --- third_party/xla/xla/service/gpu/BUILD | 35 ++++ .../xla/xla/service/gpu/gemv_rewriter.cc | 178 ++++++++++++++++++ .../xla/xla/service/gpu/gemv_rewriter.h | 44 +++++ .../xla/xla/service/gpu/gemv_rewriter_test.cc | 135 +++++++++++++ 4 files changed, 392 insertions(+) create mode 100644 third_party/xla/xla/service/gpu/gemv_rewriter.cc create mode 100644 third_party/xla/xla/service/gpu/gemv_rewriter.h create mode 100644 third_party/xla/xla/service/gpu/gemv_rewriter_test.cc diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index becd9855abdbb7..77c3ad6368341d 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1620,6 +1620,41 @@ xla_cc_test( ], ) +cc_library( + name = "gemv_rewriter", + srcs = ["gemv_rewriter.cc"], + hdrs = ["gemv_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "gemv_rewriter_test", + srcs = ["gemv_rewriter_test.cc"], + deps = [ + ":gemv_rewriter", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "split_k_gemm_rewriter", srcs = ["split_k_gemm_rewriter.cc"], diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter.cc b/third_party/xla/xla/service/gpu/gemv_rewriter.cc new file mode 100644 index 00000000000000..67ffd2b81db172 --- /dev/null +++ b/third_party/xla/xla/service/gpu/gemv_rewriter.cc @@ -0,0 +1,178 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gemv_rewriter.h" + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/shape.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +// Construct a new layout by adding a new minor-most dimension to the input +// layout. For example, {3, 2, 1, 0} is extended to {4, 3, 2, 1, 0}. +// We expect that the input layout is normalized by LayoutNormalizer, so that +// the input layout has a descending ordering. +absl::StatusOr GetLayoutWithNewMinorMostDimension( + const Layout& layout) { + // Check that the layout is normalized. + if (!LayoutUtil::IsMonotonicWithDim0Major(layout)) { + return absl::InvalidArgumentError("Layout is not normalized."); + } + return LayoutUtil::MakeDescendingLayout(layout.minor_to_major_size() + 1); +} + +class GemvRewriterVisitor : public DfsHloRewriteVisitor { + public: + absl::Status HandleDot(HloInstruction* instr) override { + HloDotInstruction* dot = Cast(instr); + const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers(); + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + + // This pass relies on dot decomposer which ensures that all non-batch + // dimensions are merged into one. + bool lhs_has_non_contracting_dim = + lhs->shape().rank() == + dim_numbers.lhs_batch_dimensions_size() + + dim_numbers.lhs_contracting_dimensions_size() + 1; + bool rhs_has_non_contracting_dim = + rhs->shape().rank() == + dim_numbers.rhs_batch_dimensions_size() + + dim_numbers.rhs_contracting_dimensions_size() + 1; + + // Skip vector-vector multiplication. + if (!lhs_has_non_contracting_dim && !rhs_has_non_contracting_dim) { + return absl::OkStatus(); + } + + if (dot->shape().is_dynamic()) { + return absl::OkStatus(); + } + + changed_ = true; + + HloComputation* computation = dot->parent(); + HloInstruction* new_lhs = lhs; + if (!lhs_has_non_contracting_dim) { + const Shape& lhs_shape = lhs->shape(); + absl::Span lhs_dimensions = lhs_shape.dimensions(); + std::vector new_lhs_dimensions(lhs_dimensions.begin(), + lhs_dimensions.end()); + new_lhs_dimensions.push_back(1); + Shape new_lhs_shape( + lhs_shape.element_type(), new_lhs_dimensions, + absl::InlinedVector(new_lhs_dimensions.size(), false), + /*tuple_shapes=*/{}); + TF_ASSIGN_OR_RETURN( + *new_lhs_shape.mutable_layout(), + GetLayoutWithNewMinorMostDimension(lhs_shape.layout())); + new_lhs = computation->AddInstruction( + HloInstruction::CreateBitcast(new_lhs_shape, lhs)); + } + + HloInstruction* new_rhs = rhs; + if (!rhs_has_non_contracting_dim) { + const Shape& rhs_shape = rhs->shape(); + absl::Span rhs_dimensions = rhs_shape.dimensions(); + std::vector new_rhs_dimensions(rhs_dimensions.begin(), + rhs_dimensions.end()); + new_rhs_dimensions.push_back(1); + Shape new_rhs_shape( + rhs_shape.element_type(), new_rhs_dimensions, + absl::InlinedVector(new_rhs_dimensions.size(), false), + /*tuple_shapes=*/{}); + TF_ASSIGN_OR_RETURN( + *new_rhs_shape.mutable_layout(), + GetLayoutWithNewMinorMostDimension(rhs_shape.layout())); + new_rhs = computation->AddInstruction( + HloInstruction::CreateBitcast(new_rhs_shape, rhs)); + } + + std::vector new_out_dimensions; + new_out_dimensions.reserve(dot->shape().dimensions().size() + 1); + for (int64_t dim_size : dot->shape().dimensions()) { + new_out_dimensions.push_back(dim_size); + } + if (!lhs_has_non_contracting_dim) { + // Insert the trivial dimension before the non-contracting dimension from + // rhs. + int non_contracting_dim_size = new_out_dimensions.back(); + new_out_dimensions[new_out_dimensions.size() - 1] = 1; + new_out_dimensions.push_back(non_contracting_dim_size); + } else { + new_out_dimensions.push_back(1); + } + + Shape new_out_shape( + dot->shape().element_type(), new_out_dimensions, + absl::InlinedVector(new_out_dimensions.size(), false), + /*tuple_shapes=*/{}); + TF_ASSIGN_OR_RETURN( + *new_out_shape.mutable_layout(), + GetLayoutWithNewMinorMostDimension(dot->shape().layout())); + + HloInstruction* new_dot = + computation->AddInstruction(HloInstruction::CreateDot( + new_out_shape, new_lhs, new_rhs, dot->dot_dimension_numbers(), + dot->precision_config())); + HloInstruction* bitcast = computation->AddInstruction( + HloInstruction::CreateBitcast(dot->shape(), new_dot)); + return computation->ReplaceInstruction(dot, bitcast); + } + + bool changed() const { return changed_; } + + private: + bool changed_ = false; +}; + +} // namespace + +absl::StatusOr GemvRewriter::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + GemvRewriterVisitor gemv_rewriter; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + TF_RETURN_IF_ERROR(computation->Accept(&gemv_rewriter)); + } + return gemv_rewriter.changed(); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter.h b/third_party/xla/xla/service/gpu/gemv_rewriter.h new file mode 100644 index 00000000000000..a041138b8af5c6 --- /dev/null +++ b/third_party/xla/xla/service/gpu/gemv_rewriter.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_GEMV_REWRITER_H_ +#define XLA_SERVICE_GPU_GEMV_REWRITER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Rewrite a matrix-vector or a vector-matrix multiplication into a +// matrix-matrix multiplication with a trivial dimension. For example, +// [m x n] @ [n] is rewritten to [m x n] @ [n x 1], and [n] @ [m x n] is +// rewritten to [n x 1] @ [m x n]. +class GemvRewriter : public HloModulePass { + public: + absl::string_view name() const override { return "gemv-rewriter"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_GEMV_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc b/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc new file mode 100644 index 00000000000000..46aee0aab3fb88 --- /dev/null +++ b/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc @@ -0,0 +1,135 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gemv_rewriter.h" + +#include +#include + +#include +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +class GemvRewriterTest : public HloTestBase {}; + +TEST_F(GemvRewriterTest, RewriteMatrixVectorMultiplicationToGemm) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[32,7] parameter(0) + p1 = f32[7] parameter(1) + ROOT d = f32[32] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + })"; + + const char* expected = R"() +// CHECK: %[[P0:.*]] = f32[32,7]{1,0} parameter(0) +// CHECK: %[[P1:.*]] = f32[7]{0} parameter(1) +// CHECK: %[[BITCAST:.*]] = f32[7,1]{1,0} bitcast(%[[P1]]) +// CHECK: %[[DOT:.*]] = f32[32,1]{1,0} dot(%[[P0]], %[[BITCAST]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} +// CHECK: ROOT %[[ROOT:.*]] = f32[32]{0} bitcast(%[[DOT]]) +})"; + + RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected); +} + +TEST_F(GemvRewriterTest, RewriteVectorMatrixMultiplicationToGemm) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[7] parameter(0) + p1 = f32[7,32] parameter(1) + ROOT d = f32[32] dot(p0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={0} + })"; + + const char* expected = R"() +// CHECK: %[[P0:.*]] = f32[7]{0} parameter(0) +// CHECK: %[[BITCAST:.*]] = f32[7,1]{1,0} bitcast(%[[P0]]) +// CHECK: %[[P1:.*]] = f32[7,32]{1,0} parameter(1) +// CHECK: %[[DOT:.*]] = f32[1,32]{1,0} dot(%[[BITCAST]], %[[P1]]), lhs_contracting_dims={0}, rhs_contracting_dims={0} +// CHECK: ROOT %[[ROOT:.*]].1 = f32[32]{0} bitcast(%[[DOT]]) +})"; + + RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected); +} + +TEST_F(GemvRewriterTest, RewriteMatrixVectorMultiplicationWithBatch) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[2,5,32,7] parameter(0) + p1 = f32[2,5,7] parameter(1) + ROOT d = f32[2,5,32] dot(p0, p1), + lhs_batch_dims={0,1}, rhs_batch_dims={0,1}, + lhs_contracting_dims={3}, rhs_contracting_dims={2} + })"; + + const char* expected = R"() +// CHECK: %[[P0:.*]] = f32[2,5,32,7]{3,2,1,0} parameter(0) +// CHECK: %[[P1:.*]] = f32[2,5,7]{2,1,0} parameter(1) +// CHECK: %[[BITCAST:.*]] = f32[2,5,7,1]{3,2,1,0} bitcast(%[[P1]]) +// CHECK: %[[DOT:.*]] = f32[2,5,32,1]{3,2,1,0} dot(%[[P0]], %[[BITCAST]]), +// CHECK-SAME: lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} +// CHECK: ROOT %[[ROOT:.*]] = f32[2,5,32]{2,1,0} bitcast(%[[DOT]]) +})"; + + RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected); +} + +TEST_F(GemvRewriterTest, DotNotRewriteVectorVectorMultiplication) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[7] parameter(0) + p1 = f32[7] parameter(1) + ROOT d = f32[] dot(p0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={0} + })"; + + RunAndFilecheckHloRewrite(hlo, GemvRewriter(), /*expected=*/std::nullopt); +} + +TEST_F(GemvRewriterTest, DoNotRewriteDotsWithNonNormalizedLayout) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[5,32,7]{2,1,0} parameter(0) + p1 = f32[5,7]{0,1} parameter(1) + ROOT d = f32[5,32]{0,1} dot(p0, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + GemvRewriter rewriter; + absl::StatusOr result = this->RunHloPass(&rewriter, module.get()); + EXPECT_FALSE(result.ok()); + EXPECT_EQ(result.status().message(), "Layout is not normalized."); +} + +} // namespace +} // namespace xla::gpu