Skip to content

Commit

Permalink
Allow compare/select on int4 data
Browse files Browse the repository at this point in the history
  • Loading branch information
sergey-kozub committed Sep 27, 2024
1 parent b7dbeb6 commit 493462a
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 3 deletions.
8 changes: 6 additions & 2 deletions xla/service/hlo_verifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2911,8 +2911,12 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
const Layout& operand_layout = operand_shape.layout();
Layout::Equal equal_predicate =
Layout::Equal().IgnoreTiles().IgnoreMemorySpace();
if (instruction->opcode() == HloOpcode::kConvert) {
// Convert instructions can change element_size_in_bits
if (instruction->opcode() == HloOpcode::kConvert ||
instruction->opcode() == HloOpcode::kCompare ||
(instruction->opcode() == HloOpcode::kSelect &&
operand_shape.element_type() == PRED)) {
// Convert and Compare instructions can change element_size_in_bits
// Select instructions ignore element_size_in_bits for predicate
equal_predicate.IgnoreElementSize();
} else if (instruction->opcode() == HloOpcode::kDynamicSlice ||
instruction->opcode() == HloOpcode::kDynamicUpdateSlice ||
Expand Down
15 changes: 15 additions & 0 deletions xla/service/hlo_verifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3488,5 +3488,20 @@ TEST_F(HloVerifierTest, NoErrorOnDuplicateChannelId) {
ASSERT_IS_OK(verifier.Run(module.get()).status());
}

TEST_F(HloVerifierTestLayoutSensitive, Int4CompareSelect) {
  // Layout-sensitive verification must accept compare/select whose data
  // operands are packed int4 (element_size_in_bits=4, the E(4) layout
  // suffix): compare produces a pred output whose element size differs
  // from its operands, and select's predicate ignores element size.
  static constexpr char kHloText[] = R"(
HloModule test
ENTRY main {
a = s4[10]{0:E(4)} parameter(0)
b = s4[10]{0:E(4)} parameter(1)
less = pred[10] compare(a, b), direction=LT
ROOT result = select(less, a, b)
})";
  // Parse without verification so the verifier under test is the one
  // exercised below.
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kHloText));
  TF_ASSERT_OK(verifier().Run(module.get()));
}

} // namespace
} // namespace xla
6 changes: 5 additions & 1 deletion xla/service/layout_normalization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,11 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor {
auto s = hlo->shape();
auto a = hlo->mutable_operand(0);
auto b = hlo->mutable_operand(1);
TF_RET_CHECK(a->shape().layout() == s.layout());
auto layout_equal = Layout::Equal();
if (hlo->opcode() == HloOpcode::kCompare) {
layout_equal.IgnoreElementSize();
}
TF_RET_CHECK(layout_equal(a->shape().layout(), s.layout()));
TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(a));
TF_ASSIGN_OR_RETURN(auto b0, GetNormalizedInput(b));

Expand Down
16 changes: 16 additions & 0 deletions xla/service/layout_normalization_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -922,5 +922,21 @@ ENTRY main.17 {
});
}

TEST_F(LayoutNormalizationTest, CompareInt4) {
  // Compare over packed int4 operands (E(4) layout) must normalize cleanly:
  // the pred result carries no element_size_in_bits annotation even though
  // the inputs do.
  const char* const hlo_text = R"(
HloModule module
ENTRY main {
a = s4[10]{0:E(4)} parameter(0)
b = s4[10]{0:E(4)} parameter(1)
ROOT out = compare(a, b), direction=EQ
}
)";

  CheckLayoutNormalization(hlo_text, R"(
// CHECK: pred[10]{0} compare({{.*}})
)");
}

} // namespace
} // namespace xla
4 changes: 4 additions & 0 deletions xla/service/shape_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3755,6 +3755,10 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) {
on_false.is_dynamic_dimension(dimension));
}
}
if (result.has_layout()) {
result.mutable_layout()->set_element_size_in_bits(
on_true.layout().element_size_in_bits());
}
return std::move(result);
}

Expand Down
12 changes: 12 additions & 0 deletions xla/service/shape_inference_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,18 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) {
HasSubstr("Expected array argument for select pred"));
}

TEST_F(ShapeInferenceTest, SelectPreservesElementSize) {
  // Shape inference for select must propagate element_size_in_bits from the
  // data operands (here, 4-bit packed S4) into the inferred result layout.
  const Shape pred_shape = ShapeUtil::MakeShape(PRED, {10});
  // Non-const: the layout is mutated to mark the elements as 4 bits wide.
  Shape packed_int4_shape = ShapeUtil::MakeShape(S4, {10});
  packed_int4_shape.mutable_layout()->set_element_size_in_bits(4);

  const absl::StatusOr<Shape> inferred_shape =
      ShapeInference::InferTernaryOpShape(
          HloOpcode::kSelect, pred_shape, packed_int4_shape,
          packed_int4_shape);
  ASSERT_IS_OK(inferred_shape.status());
  // The result must match the operand shape exactly, element size included.
  ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape, packed_int4_shape));
}

TEST_F(ShapeInferenceTest, ClampAllMatrix) {
const absl::StatusOr<Shape> inferred_shape =
ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, matrix_64_48_,
Expand Down

0 comments on commit 493462a

Please sign in to comment.