From 467bb68c4d2a1d1776c7254666a93bba90156bf1 Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Fri, 27 Sep 2024 22:02:03 +0200 Subject: [PATCH] Allow compare/select on int4 data --- xla/service/hlo_verifier.cc | 8 ++++++-- xla/service/hlo_verifier_test.cc | 15 +++++++++++++++ xla/service/layout_normalization.cc | 6 +++++- xla/service/layout_normalization_test.cc | 16 ++++++++++++++++ xla/service/shape_inference.cc | 4 ++++ xla/service/shape_inference_test.cc | 12 ++++++++++++ xla/shape_util.cc | 3 +++ xla/shape_util_test.cc | 6 ++++++ 8 files changed, 67 insertions(+), 3 deletions(-) diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc index 6578b4ff765ec..b3205232fafd9 100644 --- a/xla/service/hlo_verifier.cc +++ b/xla/service/hlo_verifier.cc @@ -2911,8 +2911,12 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { const Layout& operand_layout = operand_shape.layout(); Layout::Equal equal_predicate = Layout::Equal().IgnoreTiles().IgnoreMemorySpace(); - if (instruction->opcode() == HloOpcode::kConvert) { - // Convert instructions can change element_size_in_bits + if (instruction->opcode() == HloOpcode::kConvert || + instruction->opcode() == HloOpcode::kCompare || + (instruction->opcode() == HloOpcode::kSelect && + operand_shape.element_type() == PRED)) { + // Convert and Compare instructions can change element_size_in_bits + // Select instructions ignore element_size_in_bits for predicate equal_predicate.IgnoreElementSize(); } else if (instruction->opcode() == HloOpcode::kDynamicSlice || instruction->opcode() == HloOpcode::kDynamicUpdateSlice || diff --git a/xla/service/hlo_verifier_test.cc b/xla/service/hlo_verifier_test.cc index 1737bea0eca27..cfa7a48eaf1d0 100644 --- a/xla/service/hlo_verifier_test.cc +++ b/xla/service/hlo_verifier_test.cc @@ -3488,5 +3488,20 @@ TEST_F(HloVerifierTest, NoErrorOnDuplicateChannelId) { ASSERT_IS_OK(verifier.Run(module.get()).status()); } +TEST_F(HloVerifierTestLayoutSensitive, Int4CompareSelect) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY main { + a = s4[10]{0:E(4)} parameter(0) + b = s4[10]{0:E(4)} parameter(1) + less = pred[10] compare(a, b), direction=LT + ROOT result = select(less, a, b) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + TF_ASSERT_OK(verifier().Run(module.get())); +} + } // namespace } // namespace xla diff --git a/xla/service/layout_normalization.cc b/xla/service/layout_normalization.cc index 16781509e22c6..74100a62e2011 100644 --- a/xla/service/layout_normalization.cc +++ b/xla/service/layout_normalization.cc @@ -347,7 +347,11 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { auto s = hlo->shape(); auto a = hlo->mutable_operand(0); auto b = hlo->mutable_operand(1); - TF_RET_CHECK(a->shape().layout() == s.layout()); + auto layout_equal = Layout::Equal(); + if (hlo->opcode() == HloOpcode::kCompare) { + layout_equal.IgnoreElementSize(); + } + TF_RET_CHECK(layout_equal(a->shape().layout(), s.layout())); TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(a)); TF_ASSIGN_OR_RETURN(auto b0, GetNormalizedInput(b)); diff --git a/xla/service/layout_normalization_test.cc b/xla/service/layout_normalization_test.cc index 88ea4828ec597..6fcf848ea46be 100644 --- a/xla/service/layout_normalization_test.cc +++ b/xla/service/layout_normalization_test.cc @@ -922,5 +922,21 @@ ENTRY main.17 { }); } +TEST_F(LayoutNormalizationTest, CompareInt4) { + const char* hlo = R"( +HloModule module + +ENTRY main { + a = s4[10]{0:E(4)} parameter(0) + b = s4[10]{0:E(4)} parameter(1) + ROOT out = compare(a, b), direction=EQ +} +)"; + + CheckLayoutNormalization(hlo, R"( +// CHECK: pred[10]{0} compare({{.*}}) +)"); +} + } // namespace } // namespace xla diff --git a/xla/service/shape_inference.cc b/xla/service/shape_inference.cc index 4271cc897f41d..cbcbf97e2e4f8 100644 --- a/xla/service/shape_inference.cc +++ b/xla/service/shape_inference.cc @@ -3755,6 +3755,10 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { on_false.is_dynamic_dimension(dimension)); } } + if (result.has_layout()) { + result.mutable_layout()->set_element_size_in_bits( + on_true.layout().element_size_in_bits()); + } return std::move(result); } diff --git a/xla/service/shape_inference_test.cc b/xla/service/shape_inference_test.cc index 29ae32add358e..6c2cf78ab0245 100644 --- a/xla/service/shape_inference_test.cc +++ b/xla/service/shape_inference_test.cc @@ -239,6 +239,18 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { HasSubstr("Expected array argument for select pred")); } +TEST_F(ShapeInferenceTest, SelectPreservesElementSize) { + Shape pred_shape = ShapeUtil::MakeShape(PRED, {10}); + Shape int4_shape = ShapeUtil::MakeShape(S4, {10}); + int4_shape.mutable_layout()->set_element_size_in_bits(4); + + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, pred_shape, + int4_shape, int4_shape); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape, int4_shape)); +} + TEST_F(ShapeInferenceTest, ClampAllMatrix) { const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, matrix_64_48_, diff --git a/xla/shape_util.cc b/xla/shape_util.cc index 01f7cacfc9b44..9def58503a085 100644 --- a/xla/shape_util.cc +++ b/xla/shape_util.cc @@ -1067,6 +1067,9 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } else { Shape new_shape = original; new_shape.set_element_type(type); + if (new_shape.has_layout() && type == PRED) { + new_shape.mutable_layout()->set_element_size_in_bits(0); + } return new_shape; } } diff --git a/xla/shape_util_test.cc b/xla/shape_util_test.cc index e239a96ce6aa0..2ed5060456988 100644 --- a/xla/shape_util_test.cc +++ b/xla/shape_util_test.cc @@ -1224,6 +1224,12 @@ TEST(ShapeUtilTest, Int4ShapeSize) { layout->set_element_size_in_bits(4); EXPECT_EQ(ShapeUtil::ArrayDataSize(int4_shape2), 9216 * 6144 / 2); EXPECT_EQ(ShapeUtil::ArraySize(int4_shape2), 9216 * 6144 / 2); + + // Changing the type to PRED should clear element_size_in_bits. + Shape pred_shape = ShapeUtil::ChangeElementType(int4_shape, PRED); + EXPECT_EQ(pred_shape.layout().element_size_in_bits(), 0); + Shape u4_shape = ShapeUtil::ChangeElementType(int4_shape, U4); + EXPECT_EQ(u4_shape.layout().element_size_in_bits(), 4); } TEST(XlaShapeUtilTest, ZeroSize) {