Skip to content

Commit

Permalink
Fix umul-zero and bit-width issues in arith_simp
Browse files Browse the repository at this point in the history
We incorrectly simplified code like X*0 == CONST causing mis-compiles.

PiperOrigin-RevId: 654178127
  • Loading branch information
allight authored and copybara-github committed Jul 20, 2024
1 parent c0487ef commit 9e21b56
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 11 deletions.
3 changes: 3 additions & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,12 @@ cc_test(
"@com_google_absl//absl/time",
"//xls/common:math_util",
"//xls/common:xls_gunit_main",
"//xls/common/fuzzing:fuzztest",
"//xls/common/status:matchers",
"//xls/interpreter:ir_interpreter",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:bits_test_utils",
"//xls/ir:events",
"//xls/ir:function_builder",
"//xls/ir:ir_matcher",
Expand Down Expand Up @@ -462,6 +464,7 @@ cc_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/data_structures:inline_bitmap",
Expand Down
55 changes: 45 additions & 10 deletions xls/passes/arith_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/data_structures/inline_bitmap.h"
Expand Down Expand Up @@ -315,11 +316,14 @@ absl::StatusOr<bool> MatchComparisonOfInjectiveOp(
}
}
Bits solution;
Node* new_op;
if (binary_op->op == Op::kAdd) {
// (X + C_0) cmp C_1 => x cmp C_1 - C_0
new_op = binary_op->operand;
solution =
bits_ops::Sub(compare->constant.bits(), binary_op->constant.bits());
} else if (binary_op->op == Op::kSub) {
new_op = binary_op->operand;
if (binary_op->constant_on_lhs) {
// (C_0 - X) cmp C_1 => x cmp C_0 - C_1
solution =
Expand All @@ -333,28 +337,59 @@ absl::StatusOr<bool> MatchComparisonOfInjectiveOp(
XLS_RET_CHECK_EQ(binary_op->op, Op::kUMul);
// Need to be careful not to break the case where the comparison is actually
// impossible to satisfy or reject (eg 2*x == 3).
if (!compare->constant.bits().IsZero() &&
!bits_ops::UMod(compare->constant.bits(), binary_op->constant.bits())
.IsZero()) {
bool const_is_zero = compare->constant.bits().IsZero();
bool mul_is_zero = binary_op->constant.bits().IsZero();
bool result_is_possible =
bits_ops::UMod(compare->constant.bits(), binary_op->constant.bits())
.IsZero();
if (const_is_zero && mul_is_zero) {
VLOG(2) << "FOUND: Constant umul comparison.";
XLS_RETURN_IF_ERROR(
node->ReplaceUsesWithNew<Literal>(Value::Bool(compare->op == Op::kEq))
.status());
return true;
}
if (const_is_zero) {
solution = Bits(binary_op->operand->BitCountOrDie());
} else if (mul_is_zero || !result_is_possible) {
VLOG(2) << "FOUND: Constant umul comparison.";
XLS_RETURN_IF_ERROR(
node->ReplaceUsesWithNew<Literal>(Value::Bool(compare->op == Op::kNe))
.status());
return true;
} else {
int64_t desired_bits = std::max({binary_op->constant.bits().bit_count(),
compare->constant.bits().bit_count(),
binary_op->operand->BitCountOrDie()});
auto extend_to_bits = [&](const Bits& b) -> Bits {
if (b.bit_count() >= desired_bits) {
return b;
}
return bits_ops::ZeroExtend(b, desired_bits);
};
// (C_0 * X) cmp C_1 => X cmp C_1 / C_0
solution = bits_ops::UDiv(extend_to_bits(compare->constant.bits()),
extend_to_bits(binary_op->constant.bits()));
}
if (binary_op->operand->BitCountOrDie() == solution.bit_count()) {
new_op = binary_op->operand;
} else {
XLS_ASSIGN_OR_RETURN(
new_op,
node->function_base()->MakeNodeWithName<ExtendOp>(
binary_op->operand->loc(), binary_op->operand,
solution.bit_count(), Op::kZeroExt,
absl::StrFormat("%s_extended", binary_op->operand->GetName())));
}
// (C_0 * X) cmp C_1 => X cmp C_1 / C_0
solution = bits_ops::Truncate(
bits_ops::UDiv(compare->constant.bits(), binary_op->constant.bits()),
binary_op->operand->BitCountOrDie());
}
XLS_ASSIGN_OR_RETURN(
Literal * new_literal,
node->function_base()->MakeNode<Literal>(node->loc(), Value(solution)));

VLOG(2) << "FOUND: compairson of injective operation.";
XLS_RETURN_IF_ERROR(node->ReplaceUsesWithNew<CompareOp>(
binary_op->operand, new_literal, compare->op)
.status());
XLS_RETURN_IF_ERROR(
node->ReplaceUsesWithNew<CompareOp>(new_op, new_literal, compare->op)
.status());
return true;
}

Expand Down
64 changes: 63 additions & 1 deletion xls/passes/arith_simplification_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "xls/common/fuzzing/fuzztest.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
Expand All @@ -30,6 +32,7 @@
#include "xls/common/status/matchers.h"
#include "xls/interpreter/function_interpreter.h"
#include "xls/ir/bits.h"
#include "xls/ir/bits_test_utils.h"
#include "xls/ir/events.h"
#include "xls/ir/function.h"
#include "xls/ir/function_base.h"
Expand All @@ -51,6 +54,7 @@ namespace {

constexpr absl::Duration kProverTimeout = absl::Seconds(10);

using status_testing::IsOk;
using status_testing::IsOkAndHolds;
using ::xls::solvers::z3::ScopedVerifyEquivalence;

Expand Down Expand Up @@ -1696,7 +1700,8 @@ TEST_F(ArithSimplificationPassTest, UMulCompare) {
ScopedVerifyEquivalence sve(f);
ScopedRecordIr sri(p.get());
ASSERT_THAT(Run(p.get()), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(), m::Eq(x.node(), m::Literal(UBits(10, 32))));
EXPECT_THAT(f->return_value(),
m::Eq(m::ZeroExt(x.node()), m::Literal(UBits(10, 50))));
}

TEST_F(ArithSimplificationPassTest, UMulCompareOverflow) {
Expand All @@ -1712,6 +1717,32 @@ TEST_F(ArithSimplificationPassTest, UMulCompareOverflow) {
ASSERT_THAT(Run(p.get()), IsOkAndHolds(false));
}

TEST_F(ArithSimplificationPassTest, UMulCompareZero) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue x = fb.Param("x", p->GetBitsType(3));
// Make sure that x*<foo> == 0 is x == 0
fb.Eq(fb.Literal(UBits(0, 6)), fb.UMul(x, fb.Literal(UBits(5, 3)), 6));
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());
ScopedVerifyEquivalence sve(f);
ScopedRecordIr sri(p.get());
ASSERT_THAT(Run(p.get()), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(), m::Eq(x.node(), m::Literal(UBits(0, 3))));
}

TEST_F(ArithSimplificationPassTest, UMulMulAndCompareZero) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue x = fb.Param("x", p->GetBitsType(3));
fb.Eq(fb.Literal(UBits(0, 3)), fb.UMul(x, fb.Literal(UBits(0, 3)), 3));
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());
// Just make sure that whatever we create for a vacuously true x*0 == 0 is
// consistent.
ScopedVerifyEquivalence sve(f);
ScopedRecordIr sri(p.get());
ASSERT_THAT(Run(p.get()), IsOk());
}

TEST_F(ArithSimplificationPassTest, UMulCompareImpossible) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Expand All @@ -1725,5 +1756,36 @@ TEST_F(ArithSimplificationPassTest, UMulCompareImpossible) {
EXPECT_THAT(f->return_value(), m::Literal(Value::Bool(false)));
}

void UmulFuzz(const Bits& multiplicand, const Bits& result, int64_t var_width,
bool const_on_right, bool var_on_right) {
VerifiedPackage p("umul_fuzz");
FunctionBuilder fb("umul_fuzz", &p);
BValue eq_const = fb.Literal(result);
BValue mul_const = fb.Literal(multiplicand);
BValue var = fb.Param("param_val", p.GetBitsType(var_width));
BValue mul;
if (var_on_right) {
mul = fb.UMul(mul_const, var, result.bit_count());
} else {
mul = fb.UMul(var, mul_const, result.bit_count());
}
if (const_on_right) {
fb.Eq(mul, eq_const);
} else {
fb.Eq(eq_const, mul);
}
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());
ScopedVerifyEquivalence sve(f);
ScopedRecordIr sri(&p);
PassResults results;
ASSERT_THAT(ArithSimplificationPass(kMaxOptLevel)
.Run(&p, OptimizationPassOptions(), &results),
status_testing::IsOk());
}

FUZZ_TEST(ArithSimplificationPassFuzzTest, UmulFuzz)
.WithDomains(ArbitraryBits(16), ArbitraryBits(16), fuzztest::InRange(1, 40),
fuzztest::Arbitrary<bool>(), fuzztest::Arbitrary<bool>());

} // namespace
} // namespace xls

0 comments on commit 9e21b56

Please sign in to comment.