Skip to content

Commit

Permalink
Array simplification shatter
Browse files Browse the repository at this point in the history
Rewrite small array updates as selects of the previous value and the new value. This can lead to significantly better codegen for some targets.

PiperOrigin-RevId: 677978820
  • Loading branch information
allight authored and copybara-github committed Sep 23, 2024
1 parent 4bb43f7 commit 4994f90
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 8 deletions.
2 changes: 2 additions & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1444,6 +1444,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
],
Expand Down Expand Up @@ -2682,6 +2683,7 @@ cc_test(
"//xls/ir:ir_matcher",
"//xls/ir:ir_parser",
"//xls/ir:ir_test_base",
"//xls/ir:source_location",
"//xls/ir:value",
"//xls/solvers:z3_ir_equivalence_testutils",
"@com_google_absl//absl/status:statusor",
Expand Down
86 changes: 82 additions & 4 deletions xls/passes/array_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <deque>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <variant>
#include <vector>
Expand All @@ -27,12 +29,14 @@
#include "absl/container/flat_hash_set.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/ir/bits.h"
#include "xls/ir/bits_ops.h"
#include "xls/ir/function_base.h"
#include "xls/ir/node.h"
#include "xls/ir/node_util.h"
#include "xls/ir/nodes.h"
Expand All @@ -51,6 +55,13 @@
namespace xls {
namespace {

// How small an array needs to be before we just eliminate it in favor of the
// better dependency analysis provided by selects etc. This turns an
// array-update into a select of the old array-index and the new value and
// indexes with a select on the various indexes. In many cases this will allow
// us to entierly remove the array.
constexpr int64_t kSmallArrayLimit = 3;

// Returns true if the given index value is definitely out of bounds for the
// given array type.
bool IndexIsDefinitelyOutOfBounds(Node* index, ArrayType* array_type,
Expand Down Expand Up @@ -177,9 +188,15 @@ struct SimplifyResult {
}
};

bool IsSmallArray(Node* node) {
return node->GetType()->IsArray() &&
node->GetType()->AsArrayOrDie()->size() <= kSmallArrayLimit;
}

// Try to simplify the given array index operation.
absl::StatusOr<SimplifyResult> SimplifyArrayIndex(
ArrayIndex* array_index, const QueryEngine& query_engine) {
ArrayIndex* array_index, const QueryEngine& query_engine,
int64_t opt_level) {
// An array index with a nil index (no index operands) can be replaced by the
// array operand:
//
Expand Down Expand Up @@ -369,11 +386,15 @@ absl::StatusOr<SimplifyResult> SimplifyArrayIndex(
// Only perform this optimization if the array_index is the only user.
// Otherwise the array index(es) are duplicated which can outweigh the benefit
// of selecting the smaller element.
//
// For very small arrays (when narrowing is enabled) we will perform this
// unconditionally since we totally remove the array in these circumstances.
// TODO(meheff): Consider cases where selects with multiple users are still
// advantageous to transform.
if ((array_index->array()->Is<Select>() ||
array_index->array()->Is<PrioritySelect>()) &&
HasSingleUse(array_index->array())) {
(HasSingleUse(array_index->array()) ||
(IsSmallArray(array_index->array()) && SplitsEnabled(opt_level)))) {
VLOG(2) << absl::StrFormat(
"Replacing array-index of select with select of array-indexes: %s",
array_index->ToString());
Expand Down Expand Up @@ -510,6 +531,61 @@ absl::StatusOr<SimplifyResult> SimplifyArrayIndex(
return SimplifyResult::Changed({array_index});
}

// If the array is really small we unconditionally unwind the array-update to
// give other transforms a chance to fully remove the array.
//
// We only do this if they both are 1-long indexes to avoid having to recreate
// the other array elements.
//
// (array-index (array-update C V {update-idx}) {arr-idx})
//
// Transforms to:
//
// (let ((effective-read-idx (if (> arr-idx (array-len C))
// (- (array-len C) 1)
// arr-idx)))
// (if (= effective-read-idx update-idx)
// V
// (array-index C arr-idx)))
//
// NB Since in the case where an update did not actually change the array the
// update-idx must be out-of-bounds that means the only way the bounded
// arr-idx can match the update-idx is if the update-idx is in bounds and so
// the update acctually happened.
if (IsSmallArray(array_update) && array_update->indices().size() == 1 &&
array_index->indices().size() == 1 && SplitsEnabled(opt_level)) {
VLOG(2) << "Replacing " << array_index << " with select using "
<< array_update << " context";
FunctionBase* fb = array_update->function_base();
auto name_fmt = [&](Node* src, std::string_view postfix) -> std::string {
if (src->HasAssignedName()) {
return absl::StrCat(src->GetNameView(), postfix);
}
return "";
};
int64_t array_bounds = array_update->GetType()->AsArrayOrDie()->size() - 1;
XLS_ASSIGN_OR_RETURN(Node * bounded_arr_idx,
UnsignedUpperBoundLiteral(
array_index->indices().front(), array_bounds));
XLS_ASSIGN_OR_RETURN(
Node * index_is_updated_value,
CompareNumeric(array_update->indices().front(), bounded_arr_idx,
Op::kEq, name_fmt(array_index, "_is_updated_value")));
XLS_ASSIGN_OR_RETURN(
Node * old_value_get,
fb->MakeNodeWithName<ArrayIndex>(
array_index->loc(), array_update->array_to_update(),
array_index->indices(), name_fmt(array_index, "_former_value")));
XLS_RETURN_IF_ERROR(
array_index
->ReplaceUsesWithNew<PrioritySelect>(
index_is_updated_value,
absl::MakeConstSpan({array_update->update_value()}),
old_value_get)
.status());
return SimplifyResult::Changed({old_value_get});
}

return SimplifyResult::Unchanged();
}

Expand Down Expand Up @@ -1380,6 +1456,8 @@ absl::StatusOr<bool> ArraySimplificationPass::RunOnFunctionBaseInternal(
}

while (!worklist.empty()) {
VLOG(2) << "Worklist is " << worklist.size() << "/" << func->node_count()
<< " nodes for " << func->name();
Node* node = remove_from_worklist();

if (node->IsDead()) {
Expand All @@ -1389,8 +1467,8 @@ absl::StatusOr<bool> ArraySimplificationPass::RunOnFunctionBaseInternal(
SimplifyResult result = {.changed = false, .new_worklist_nodes = {}};
if (node->Is<ArrayIndex>()) {
ArrayIndex* array_index = node->As<ArrayIndex>();
XLS_ASSIGN_OR_RETURN(result,
SimplifyArrayIndex(array_index, query_engine));
XLS_ASSIGN_OR_RETURN(
result, SimplifyArrayIndex(array_index, query_engine, opt_level_));
} else if (node->Is<ArrayUpdate>()) {
XLS_ASSIGN_OR_RETURN(
result, SimplifyArrayUpdate(node->As<ArrayUpdate>(), query_engine,
Expand Down
71 changes: 67 additions & 4 deletions xls/passes/array_simplification_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "xls/ir/ir_parser.h"
#include "xls/ir/ir_test_base.h"
#include "xls/ir/package.h"
#include "xls/ir/source_location.h"
#include "xls/ir/value.h"
#include "xls/passes/constant_folding_pass.h"
#include "xls/passes/dce_pass.h"
Expand All @@ -39,6 +40,9 @@ namespace xls {
namespace {

using status_testing::IsOkAndHolds;
using ::testing::Each;
using ::testing::Not;
using ::xls::solvers::z3::ScopedVerifyEquivalence;

class ArraySimplificationPassTest : public IrTestBase {
protected:
Expand Down Expand Up @@ -269,8 +273,9 @@ TEST_F(ArraySimplificationPassTest, IndexingArrayUpdateOperationUnknownIndex) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u32 = p->GetBitsType(32);
BValue a = fb.Array(
{fb.Param("x", u32), fb.Param("y", u32), fb.Param("z", u32)}, u32);
BValue a = fb.Array({fb.Param("w", u32), fb.Param("x", u32),
fb.Param("y", u32), fb.Param("z", u32)},
u32);
BValue update_index = fb.Param("idx", u32);
BValue array_update = fb.ArrayUpdate(a, fb.Param("q", u32), {update_index});
BValue index = fb.Literal(Value(UBits(1, 16)));
Expand All @@ -287,8 +292,9 @@ TEST_F(ArraySimplificationPassTest,
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u32 = p->GetBitsType(32);
BValue a = fb.Array(
{fb.Param("x", u32), fb.Param("y", u32), fb.Param("z", u32)}, u32);
BValue a = fb.Array({fb.Param("w", u32), fb.Param("x", u32),
fb.Param("y", u32), fb.Param("z", u32)},
u32);
BValue index = fb.Param("idx", u32);
BValue array_update = fb.ArrayUpdate(a, fb.Param("q", u32), {index});
fb.ArrayIndex(array_update, {index});
Expand Down Expand Up @@ -1350,5 +1356,62 @@ TEST_F(ArraySimplificationPassTest, IndexingArrayConcatNonConstant) {
m::ArrayIndex(m::ArrayConcat(), {m::Param(), m::Param()}));
}

TEST_F(ArraySimplificationPassTest, BasicRemoval) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue arr = fb.Array(
{fb.Param("v1", p->GetBitsType(32)), fb.Param("v2", p->GetBitsType(32))},
p->GetBitsType(32));
fb.ArrayIndex(arr, {fb.Param("idx", p->GetBitsType(32))});

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());
ScopedVerifyEquivalence sve(f);
ASSERT_THAT(Run(f), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(),
m::Select(m::Param("idx"), {m::Param("v1")}, m::Param("v2")));
}

TEST_F(ArraySimplificationPassTest, RemovalOfUpdateIndexLiteral) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue arr = fb.Array(
{fb.Param("v1", p->GetBitsType(32)), fb.Param("v2", p->GetBitsType(32))},
p->GetBitsType(32));
BValue updated = fb.ArrayUpdate(arr, fb.Param("v3", p->GetBitsType(32)),
{fb.Param("idx", p->GetBitsType(32))});
fb.ArrayIndex(updated, {fb.Literal(UBits(0, 32))});

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());
ScopedVerifyEquivalence sve(f);
ASSERT_THAT(Run(f), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(),
m::PrioritySelect(testing::A<const Node*>(), {m::Param("v3")},
m::Param("v1")));
EXPECT_THAT(f->nodes(), Each(Not(m::Array())));
}

TEST_F(ArraySimplificationPassTest, RemovalOfUpdate) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue arr = fb.Array(
{fb.Param("v1", p->GetBitsType(32)), fb.Param("v2", p->GetBitsType(32))},
p->GetBitsType(32), SourceInfo(), "arr");
BValue updated = fb.ArrayUpdate(arr, fb.Param("v3", p->GetBitsType(32)),
{fb.Param("idx", p->GetBitsType(32))},
SourceInfo(), "updated_arr");
fb.ArrayIndex(updated, {fb.Param("idx2", p->GetBitsType(32))}, SourceInfo(),
"arr_read");

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());
ScopedVerifyEquivalence sve(f);
ASSERT_THAT(Run(f), IsOkAndHolds(true));
auto original_value =
m::Select(m::Param("idx2"), {m::Param("v1")}, m::Param("v2"));
EXPECT_THAT(f->return_value(),
m::PrioritySelect(testing::A<const Node*>(), {m::Param("v3")},
original_value));
EXPECT_THAT(f->nodes(), Each(Not(m::Array())));
}

} // namespace
} // namespace xls

0 comments on commit 4994f90

Please sign in to comment.