From 493462aab448502dd67a999082809f5e30526999 Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Fri, 27 Sep 2024 22:02:03 +0200 Subject: [PATCH] Allow compare/select on int4 data --- xla/service/hlo_verifier.cc | 8 ++++++-- xla/service/hlo_verifier_test.cc | 15 +++++++++++++++ xla/service/layout_normalization.cc | 6 +++++- xla/service/layout_normalization_test.cc | 16 ++++++++++++++++ xla/service/shape_inference.cc | 4 ++++ xla/service/shape_inference_test.cc | 12 ++++++++++++ 6 files changed, 58 insertions(+), 3 deletions(-) diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc index 6578b4ff765ec9..b3205232fafd98 100644 --- a/xla/service/hlo_verifier.cc +++ b/xla/service/hlo_verifier.cc @@ -2911,8 +2911,12 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { const Layout& operand_layout = operand_shape.layout(); Layout::Equal equal_predicate = Layout::Equal().IgnoreTiles().IgnoreMemorySpace(); - if (instruction->opcode() == HloOpcode::kConvert) { - // Convert instructions can change element_size_in_bits + if (instruction->opcode() == HloOpcode::kConvert || + instruction->opcode() == HloOpcode::kCompare || + (instruction->opcode() == HloOpcode::kSelect && + operand_shape.element_type() == PRED)) { + // Convert and Compare instructions can change element_size_in_bits + // Select instructions ignore element_size_in_bits for predicate equal_predicate.IgnoreElementSize(); } else if (instruction->opcode() == HloOpcode::kDynamicSlice || instruction->opcode() == HloOpcode::kDynamicUpdateSlice || diff --git a/xla/service/hlo_verifier_test.cc b/xla/service/hlo_verifier_test.cc index 1737bea0eca27b..cfa7a48eaf1d0e 100644 --- a/xla/service/hlo_verifier_test.cc +++ b/xla/service/hlo_verifier_test.cc @@ -3488,5 +3488,20 @@ TEST_F(HloVerifierTest, NoErrorOnDuplicateChannelId) { ASSERT_IS_OK(verifier.Run(module.get()).status()); } +TEST_F(HloVerifierTestLayoutSensitive, Int4CompareSelect) { + const char* const kModuleStr = R"( + HloModule test + + 
ENTRY main { + a = s4[10]{0:E(4)} parameter(0) + b = s4[10]{0:E(4)} parameter(1) + less = pred[10] compare(a, b), direction=LT + ROOT result = select(less, a, b) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + TF_ASSERT_OK(verifier().Run(module.get())); +} + } // namespace } // namespace xla diff --git a/xla/service/layout_normalization.cc b/xla/service/layout_normalization.cc index 16781509e22c60..74100a62e20111 100644 --- a/xla/service/layout_normalization.cc +++ b/xla/service/layout_normalization.cc @@ -347,7 +347,11 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { auto s = hlo->shape(); auto a = hlo->mutable_operand(0); auto b = hlo->mutable_operand(1); - TF_RET_CHECK(a->shape().layout() == s.layout()); + auto layout_equal = Layout::Equal(); + if (hlo->opcode() == HloOpcode::kCompare) { + layout_equal.IgnoreElementSize(); + } + TF_RET_CHECK(layout_equal(a->shape().layout(), s.layout())); TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(a)); TF_ASSIGN_OR_RETURN(auto b0, GetNormalizedInput(b)); diff --git a/xla/service/layout_normalization_test.cc b/xla/service/layout_normalization_test.cc index 88ea4828ec597a..6fcf848ea46be8 100644 --- a/xla/service/layout_normalization_test.cc +++ b/xla/service/layout_normalization_test.cc @@ -922,5 +922,21 @@ ENTRY main.17 { }); } +TEST_F(LayoutNormalizationTest, CompareInt4) { + const char* hlo = R"( +HloModule module + +ENTRY main { + a = s4[10]{0:E(4)} parameter(0) + b = s4[10]{0:E(4)} parameter(1) + ROOT out = compare(a, b), direction=EQ +} +)"; + + CheckLayoutNormalization(hlo, R"( +// CHECK: pred[10]{0} compare({{.*}}) +)"); +} + } // namespace } // namespace xla diff --git a/xla/service/shape_inference.cc b/xla/service/shape_inference.cc index 4271cc897f41d7..cbcbf97e2e4f88 100644 --- a/xla/service/shape_inference.cc +++ b/xla/service/shape_inference.cc @@ -3755,6 +3755,10 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { 
on_false.is_dynamic_dimension(dimension)); } } + if (result.has_layout()) { + result.mutable_layout()->set_element_size_in_bits( + on_true.layout().element_size_in_bits()); + } return std::move(result); } diff --git a/xla/service/shape_inference_test.cc b/xla/service/shape_inference_test.cc index 29ae32add358e3..6c2cf78ab0245c 100644 --- a/xla/service/shape_inference_test.cc +++ b/xla/service/shape_inference_test.cc @@ -239,6 +239,18 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { HasSubstr("Expected array argument for select pred")); } +TEST_F(ShapeInferenceTest, SelectPreservesElementSize) { + Shape pred_shape = ShapeUtil::MakeShape(PRED, {10}); + Shape int4_shape = ShapeUtil::MakeShape(S4, {10}); + int4_shape.mutable_layout()->set_element_size_in_bits(4); + + const absl::StatusOr<Shape> inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, pred_shape, + int4_shape, int4_shape); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape, int4_shape)); +} + TEST_F(ShapeInferenceTest, ClampAllMatrix) { const absl::StatusOr<Shape> inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, matrix_64_48_,