diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index becd9855abdbb7..77c3ad6368341d 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1620,6 +1620,41 @@ xla_cc_test( ], ) +cc_library( + name = "gemv_rewriter", + srcs = ["gemv_rewriter.cc"], + hdrs = ["gemv_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "gemv_rewriter_test", + srcs = ["gemv_rewriter_test.cc"], + deps = [ + ":gemv_rewriter", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "split_k_gemm_rewriter", srcs = ["split_k_gemm_rewriter.cc"], diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter.cc b/third_party/xla/xla/service/gpu/gemv_rewriter.cc new file mode 100644 index 00000000000000..67ffd2b81db172 --- /dev/null +++ b/third_party/xla/xla/service/gpu/gemv_rewriter.cc @@ -0,0 +1,178 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gemv_rewriter.h" + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/shape.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +// Construct a new layout by adding a new minor-most dimension to the input +// layout. For example, {3, 2, 1, 0} is extended to {4, 3, 2, 1, 0}. +// We expect that the input layout is normalized by LayoutNormalizer, so that +// the input layout has a descending ordering. +absl::StatusOr GetLayoutWithNewMinorMostDimension( + const Layout& layout) { + // Check that the layout is normalized. + if (!LayoutUtil::IsMonotonicWithDim0Major(layout)) { + return absl::InvalidArgumentError("Layout is not normalized."); + } + return LayoutUtil::MakeDescendingLayout(layout.minor_to_major_size() + 1); +} + +class GemvRewriterVisitor : public DfsHloRewriteVisitor { + public: + absl::Status HandleDot(HloInstruction* instr) override { + HloDotInstruction* dot = Cast(instr); + const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers(); + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + + // This pass relies on dot decomposer which ensures that all non-batch + // dimensions are merged into one. + bool lhs_has_non_contracting_dim = + lhs->shape().rank() == + dim_numbers.lhs_batch_dimensions_size() + + dim_numbers.lhs_contracting_dimensions_size() + 1; + bool rhs_has_non_contracting_dim = + rhs->shape().rank() == + dim_numbers.rhs_batch_dimensions_size() + + dim_numbers.rhs_contracting_dimensions_size() + 1; + + // Skip vector-vector multiplication. + if (!lhs_has_non_contracting_dim && !rhs_has_non_contracting_dim) { + return absl::OkStatus(); + } + + if (dot->shape().is_dynamic()) { + return absl::OkStatus(); + } + + changed_ = true; + + HloComputation* computation = dot->parent(); + HloInstruction* new_lhs = lhs; + if (!lhs_has_non_contracting_dim) { + const Shape& lhs_shape = lhs->shape(); + absl::Span lhs_dimensions = lhs_shape.dimensions(); + std::vector new_lhs_dimensions(lhs_dimensions.begin(), + lhs_dimensions.end()); + new_lhs_dimensions.push_back(1); + Shape new_lhs_shape( + lhs_shape.element_type(), new_lhs_dimensions, + absl::InlinedVector(new_lhs_dimensions.size(), false), + /*tuple_shapes=*/{}); + TF_ASSIGN_OR_RETURN( + *new_lhs_shape.mutable_layout(), + GetLayoutWithNewMinorMostDimension(lhs_shape.layout())); + new_lhs = computation->AddInstruction( + HloInstruction::CreateBitcast(new_lhs_shape, lhs)); + } + + HloInstruction* new_rhs = rhs; + if (!rhs_has_non_contracting_dim) { + const Shape& rhs_shape = rhs->shape(); + absl::Span rhs_dimensions = rhs_shape.dimensions(); + std::vector new_rhs_dimensions(rhs_dimensions.begin(), + rhs_dimensions.end()); + new_rhs_dimensions.push_back(1); + Shape new_rhs_shape( + rhs_shape.element_type(), new_rhs_dimensions, + absl::InlinedVector(new_rhs_dimensions.size(), false), + /*tuple_shapes=*/{}); + TF_ASSIGN_OR_RETURN( + *new_rhs_shape.mutable_layout(), + GetLayoutWithNewMinorMostDimension(rhs_shape.layout())); + new_rhs = computation->AddInstruction( + HloInstruction::CreateBitcast(new_rhs_shape, rhs)); + } + + std::vector new_out_dimensions; + new_out_dimensions.reserve(dot->shape().dimensions().size() + 1); + for (int64_t dim_size : dot->shape().dimensions()) { + new_out_dimensions.push_back(dim_size); + } + if (!lhs_has_non_contracting_dim) { + // Insert the trivial dimension before the non-contracting dimension from + // rhs. + int non_contracting_dim_size = new_out_dimensions.back(); + new_out_dimensions[new_out_dimensions.size() - 1] = 1; + new_out_dimensions.push_back(non_contracting_dim_size); + } else { + new_out_dimensions.push_back(1); + } + + Shape new_out_shape( + dot->shape().element_type(), new_out_dimensions, + absl::InlinedVector(new_out_dimensions.size(), false), + /*tuple_shapes=*/{}); + TF_ASSIGN_OR_RETURN( + *new_out_shape.mutable_layout(), + GetLayoutWithNewMinorMostDimension(dot->shape().layout())); + + HloInstruction* new_dot = + computation->AddInstruction(HloInstruction::CreateDot( + new_out_shape, new_lhs, new_rhs, dot->dot_dimension_numbers(), + dot->precision_config())); + HloInstruction* bitcast = computation->AddInstruction( + HloInstruction::CreateBitcast(dot->shape(), new_dot)); + return computation->ReplaceInstruction(dot, bitcast); + } + + bool changed() const { return changed_; } + + private: + bool changed_ = false; +}; + +} // namespace + +absl::StatusOr GemvRewriter::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + GemvRewriterVisitor gemv_rewriter; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + TF_RETURN_IF_ERROR(computation->Accept(&gemv_rewriter)); + } + return gemv_rewriter.changed(); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter.h b/third_party/xla/xla/service/gpu/gemv_rewriter.h new file mode 100644 index 00000000000000..a041138b8af5c6 --- /dev/null +++ b/third_party/xla/xla/service/gpu/gemv_rewriter.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_GEMV_REWRITER_H_ +#define XLA_SERVICE_GPU_GEMV_REWRITER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Rewrite a matrix-vector or a vector-matrix multiplication into a +// matrix-matrix multiplication with a trivial dimension. For example, +// [m x n] @ [n] is rewritten to [m x n] @ [n x 1], and [n] @ [m x n] is +// rewritten to [n x 1] @ [m x n]. +class GemvRewriter : public HloModulePass { + public: + absl::string_view name() const override { return "gemv-rewriter"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_GEMV_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc b/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc new file mode 100644 index 00000000000000..46aee0aab3fb88 --- /dev/null +++ b/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc @@ -0,0 +1,135 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gemv_rewriter.h" + +#include +#include + +#include +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +class GemvRewriterTest : public HloTestBase {}; + +TEST_F(GemvRewriterTest, RewriteMatrixVectorMultiplicationToGemm) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[32,7] parameter(0) + p1 = f32[7] parameter(1) + ROOT d = f32[32] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + })"; + + const char* expected = R"() +// CHECK: %[[P0:.*]] = f32[32,7]{1,0} parameter(0) +// CHECK: %[[P1:.*]] = f32[7]{0} parameter(1) +// CHECK: %[[BITCAST:.*]] = f32[7,1]{1,0} bitcast(%[[P1]]) +// CHECK: %[[DOT:.*]] = f32[32,1]{1,0} dot(%[[P0]], %[[BITCAST]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} +// CHECK: ROOT %[[ROOT:.*]] = f32[32]{0} bitcast(%[[DOT]]) +})"; + + RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected); +} + +TEST_F(GemvRewriterTest, RewriteVectorMatrixMultiplicationToGemm) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[7] parameter(0) + p1 = f32[7,32] parameter(1) + ROOT d = f32[32] dot(p0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={0} + })"; + + const char* expected = R"() +// CHECK: %[[P0:.*]] = f32[7]{0} parameter(0) +// CHECK: %[[BITCAST:.*]] = f32[7,1]{1,0} bitcast(%[[P0]]) +// CHECK: %[[P1:.*]] = f32[7,32]{1,0} parameter(1) +// CHECK: %[[DOT:.*]] = f32[1,32]{1,0} dot(%[[BITCAST]], %[[P1]]), lhs_contracting_dims={0}, rhs_contracting_dims={0} +// CHECK: ROOT %[[ROOT:.*]].1 = f32[32]{0} bitcast(%[[DOT]]) +})"; + + RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected); +} + +TEST_F(GemvRewriterTest, RewriteMatrixVectorMultiplicationWithBatch) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[2,5,32,7] parameter(0) + p1 = f32[2,5,7] parameter(1) + ROOT d = f32[2,5,32] dot(p0, p1), + lhs_batch_dims={0,1}, rhs_batch_dims={0,1}, + lhs_contracting_dims={3}, rhs_contracting_dims={2} + })"; + + const char* expected = R"() +// CHECK: %[[P0:.*]] = f32[2,5,32,7]{3,2,1,0} parameter(0) +// CHECK: %[[P1:.*]] = f32[2,5,7]{2,1,0} parameter(1) +// CHECK: %[[BITCAST:.*]] = f32[2,5,7,1]{3,2,1,0} bitcast(%[[P1]]) +// CHECK: %[[DOT:.*]] = f32[2,5,32,1]{3,2,1,0} dot(%[[P0]], %[[BITCAST]]), +// CHECK-SAME: lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} +// CHECK: ROOT %[[ROOT:.*]] = f32[2,5,32]{2,1,0} bitcast(%[[DOT]]) +})"; + + RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected); +} + +TEST_F(GemvRewriterTest, DotNotRewriteVectorVectorMultiplication) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[7] parameter(0) + p1 = f32[7] parameter(1) + ROOT d = f32[] dot(p0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={0} + })"; + + RunAndFilecheckHloRewrite(hlo, GemvRewriter(), /*expected=*/std::nullopt); +} + +TEST_F(GemvRewriterTest, DoNotRewriteDotsWithNonNormalizedLayout) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[5,32,7]{2,1,0} parameter(0) + p1 = f32[5,7]{0,1} parameter(1) + ROOT d = f32[5,32]{0,1} dot(p0, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + GemvRewriter rewriter; + absl::StatusOr result = this->RunHloPass(&rewriter, module.get()); + EXPECT_FALSE(result.ok()); + EXPECT_EQ(result.status().message(), "Layout is not normalized."); +} + +} // namespace +} // namespace xla::gpu