From d1b26f2d3be99b9b16fb64a1d861482781a4afa2 Mon Sep 17 00:00:00 2001 From: Alex Light Date: Fri, 13 Sep 2024 16:35:19 -0700 Subject: [PATCH] Limit path length in BDD-query engine ImpliedNodeTernary This operation could potentially create exceptionally long path lengths which could cause timeouts or OOMs. PiperOrigin-RevId: 674475069 --- xls/passes/bdd_query_engine.cc | 3 ++ xls/passes/bdd_query_engine_test.cc | 48 +++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/xls/passes/bdd_query_engine.cc b/xls/passes/bdd_query_engine.cc index b174777a04..93cf8484fc 100644 --- a/xls/passes/bdd_query_engine.cc +++ b/xls/passes/bdd_query_engine.cc @@ -191,6 +191,9 @@ std::optional BddQueryEngine::ImpliedNodeTernary( conjuction_bit = conjunction_value ? *conjuction_bit : bdd().Not(*conjuction_bit); bdd_predicate_bit = bdd().And(bdd_predicate_bit, *conjuction_bit); + if (ExceedsPathLimit(bdd_predicate_bit)) { + return std::nullopt; + } } // If the predicate evaluates to false, we can't determine // what node value it implies. That is, !predicate || node_bit diff --git a/xls/passes/bdd_query_engine_test.cc b/xls/passes/bdd_query_engine_test.cc index bb7af4a20a..e87b46140c 100644 --- a/xls/passes/bdd_query_engine_test.cc +++ b/xls/passes/bdd_query_engine_test.cc @@ -14,6 +14,11 @@ #include "xls/passes/bdd_query_engine.h" +#include +#include +#include +#include + #include "gmock/gmock.h" #include "gtest/gtest.h" #include "xls/common/status/matchers.h" @@ -289,5 +294,48 @@ TEST_F(BddQueryEngineTest, BitValuesImplyNodeValuePredicateAlwaysFalse) { EXPECT_FALSE(result.has_value()); } +TEST_F(BddQueryEngineTest, ImpliedNodeTernaryChecksPathLength) { + auto p = CreatePackage(); + // Function found with fuzzing and would OOM the process when BDD query engine + // is called with specific arguments. + XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction( + R"( +fn __sample__main(x0: bits[16], x1: bits[8], x2: bits[10], x3: bits[1]) -> (bits[13], bits[16], bits[16][6], bits[16], bits[16]) { + literal.73: bits[6] = literal(value=0, id=73) + concat.74: bits[16] = concat(literal.73, x2, id=74) + x4: bits[16] = and(concat.74, x0, id=6, pos=[(0,4,32)]) + bit_slice.97: bits[10] = bit_slice(x4, start=0, width=10, id=97, pos=[(0,20,25)]) + literal.100: bits[10] = literal(value=2, id=100, pos=[(0,20,25)]) + x22: bits[16] = neg(x0, id=39, pos=[(0,21,23)]) + ugt.99: bits[1] = ugt(bit_slice.97, literal.100, id=99, pos=[(0,20,25)]) + literal.36: bits[16] = literal(value=2, id=36, pos=[(0,20,44)]) + bit_slice.49: bits[13] = bit_slice(x4, start=0, width=13, id=49) + x23: bits[13] = bit_slice(x22, start=3, width=13, id=88, pos=[(0,22,26)]) + sel.78: bits[16] = sel(ugt.99, cases=[x4, literal.36], id=78, pos=[(0,20,25)]) + x28: bits[13] = and(bit_slice.49, x23, id=50, pos=[(0,26,33)]) + x12: bits[16][6] = array(x0, x4, x0, x0, x4, x0, id=92, pos=[(0,11,28)]) + x17: bits[16] = literal(value=0, id=9, pos=[(0,5,47)]) + x21: bits[16] = sel(sel.78, cases=[x0, x4], default=x0, id=96, pos=[(0,20,24)]) + ret tuple.51: (bits[13], bits[16], bits[16][6], bits[16], bits[16]) = tuple(x28, x0, x12, x17, x21, id=51, pos=[(0,27,8)]) +} +)", + p.get())); + BddQueryEngine query_engine(/*path_limit=*/1024); + XLS_ASSERT_OK(query_engine.Populate(f).status()); + XLS_ASSERT_OK_AND_ASSIGN(Node * sel_node, f->GetNode("sel.78")); + // NB The specific target is irrelevant. + XLS_ASSERT_OK_AND_ASSIGN(Node * target, f->GetNode("x4")); + std::vector> vals; + // Set sel-node to 0x1. + vals.reserve(sel_node->BitCountOrDie()); + vals.push_back({TreeBitLocation(sel_node, 0), true}); + for (int64_t i = 1; i < sel_node->BitCountOrDie(); ++i) { + vals.push_back({TreeBitLocation(sel_node, i), false}); + } + // This will blow up the path depth. + EXPECT_EQ(query_engine.ImpliedNodeTernary(vals, target), std::nullopt) + << "Expected failure to find result due to path-size explosion."; +} + } // namespace } // namespace xls