Skip to content

Commit

Permalink
Move some large stack frames off recursive paths. (#8507)
Browse files Browse the repository at this point in the history
* Remove large stack frames from recursive paths

There were several large (>2kb) stack frames on recursive paths, which
invites stack overflow. In all cases the code could be moved into a
helper function out of the recursive path, eliminating the problem. Some
of these stack frames were also shrunk by removing state from the Intrin
IR matcher, and removing unnecessary precision in the IsInt and IsUInt
IR matchers.

Also added IRMatcher helpers for a few more intrinsics

Note the tables in HexagonOptimize are unchanged, they just got indented
more by being moved into a lambda.

* Update ConstantBounds.cpp

---------

Co-authored-by: Steven Johnson <[email protected]>
  • Loading branch information
abadams and steven-johnson authored Dec 13, 2024
1 parent c3f4de0 commit 5f17d6f
Show file tree
Hide file tree
Showing 8 changed files with 387 additions and 328 deletions.
1 change: 0 additions & 1 deletion src/ConstantBounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ ConstantInterval bounds_helper(const Expr &e,
ScopedBinding bind(scope, op->name, recurse(op->value));
return recurse(op->body);
} else if (const Call *op = e.as<Call>()) {
ConstantInterval result;
if (op->is_intrinsic(Call::abs)) {
return abs(recurse(op->args[0]));
} else if (op->is_intrinsic(Call::absd)) {
Expand Down
20 changes: 16 additions & 4 deletions src/FindIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,8 +516,12 @@ class FindIntrinsics : public IRMutator {
return IRMutator::visit(op);
}

Expr value = mutate(op->value);
return visit_cast(op, mutate(op->value));
}

// Isolated in its own function to keep the (large) stack frame off the
// recursive path.
HALIDE_NEVER_INLINE Expr visit_cast(const Cast *op, Expr &&value) {
// This mutator can generate redundant casts. We can't use the simplifier because it
// undoes some of the intrinsic lowering here, and it causes some problems due to
// factoring (instead of distributing) constants.
Expand Down Expand Up @@ -550,6 +554,7 @@ class FindIntrinsics : public IRMutator {
auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits);
auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint;
auto x_y_same_sign = (is_int(x) && is_int(y)) || (is_uint(x) && is_uint(y));

if (
// Saturating patterns
rewrite(max(min(widening_add(x, y), upper), lower),
Expand All @@ -566,11 +571,11 @@ class FindIntrinsics : public IRMutator {

rewrite(min(widening_add(x, y), upper),
saturating_add(x, y),
op->type.is_uint() && is_x_same_uint) ||
is_x_same_uint) ||

rewrite(max(widening_sub(x, y), lower),
saturating_sub(x, y),
op->type.is_uint() && is_x_same_uint) ||
is_x_same_uint) ||

// Saturating narrow patterns.
rewrite(max(min(x, upper), lower),
Expand Down Expand Up @@ -721,10 +726,17 @@ class FindIntrinsics : public IRMutator {
op = mutated.as<Call>();
if (!op) {
return mutated;
} else {
return visit_call(op);
}
}

// Isolated in its own function to keep the (large) stack frame off the
// recursive path. The Call node has already been mutated by the base class
// visitor.
HALIDE_NEVER_INLINE Expr visit_call(const Call *op) {
auto rewrite = IRMatcher::rewriter(op, op->type);
if (rewrite(intrin(Call::abs, widening_sub(x, y)), cast(op->type, intrin(Call::absd, x, y))) ||
if (rewrite(abs(widening_sub(x, y)), cast(op->type, absd(x, y))) ||
false) {
return rewrite.result;
}
Expand Down
512 changes: 264 additions & 248 deletions src/HexagonOptimize.cpp

Large diffs are not rendered by default.

146 changes: 89 additions & 57 deletions src/IRMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -1325,15 +1325,30 @@ constexpr int const_min(int a, int b) {
return a < b ? a : b;
}

template<typename... Args>
template<Call::IntrinsicOp intrin>
struct OptionalIntrinType {
bool check(const Type &) const {
return true;
}
};

template<>
struct OptionalIntrinType<Call::saturating_cast> {
halide_type_t type;
bool check(const Type &t) const {
return t == Type(type);
}
};

template<Call::IntrinsicOp intrin, typename... Args>
struct Intrin {
struct pattern_tag {};
Call::IntrinsicOp intrin;
std::tuple<Args...> args;
// The type of the output of the intrinsic node.
// Only necessary in cases where it can't be inferred
// from the input types (e.g. saturating_cast).
Type optional_type_hint;

OptionalIntrinType<intrin> optional_type_hint;

static constexpr uint32_t binds = bitwise_or_reduce((bindings<Args>::mask)...);

Expand Down Expand Up @@ -1362,7 +1377,7 @@ struct Intrin {
}
const Call &c = (const Call &)e;
return (c.is_intrinsic(intrin) &&
((optional_type_hint == Type()) || optional_type_hint == e.type) &&
optional_type_hint.check(e.type) &&
match_args<0, bound>(0, c, state));
}

Expand Down Expand Up @@ -1394,8 +1409,8 @@ struct Intrin {
return likely_if_innermost(std::move(arg0));
} else if (intrin == Call::abs) {
return abs(std::move(arg0));
} else if (intrin == Call::saturating_cast) {
return saturating_cast(optional_type_hint, std::move(arg0));
} else if constexpr (intrin == Call::saturating_cast) {
return saturating_cast(optional_type_hint.type, std::move(arg0));
}

Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
Expand Down Expand Up @@ -1489,98 +1504,113 @@ struct Intrin {
}

HALIDE_ALWAYS_INLINE
Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
: intrin(intrin), args(args...) {
Intrin(Args... args) noexcept
: args(args...) {
}
};

template<typename... Args>
std::ostream &operator<<(std::ostream &s, const Intrin<Args...> &op) {
s << op.intrin << "(";
template<Call::IntrinsicOp intrin, typename... Args>
std::ostream &operator<<(std::ostream &s, const Intrin<intrin, Args...> &op) {
s << intrin << "(";
op.print_args(s);
s << ")";
return s;
}

template<typename... Args>
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
return {intrinsic_op, pattern_arg(args)...};
}

template<typename A, typename B>
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::widen_right_add, pattern_arg(a), pattern_arg(b)};
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::widen_right_mul, pattern_arg(a), pattern_arg(b)};
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::widen_right_sub, pattern_arg(a), pattern_arg(b)};
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}

template<typename A, typename B>
auto widening_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::widening_add, pattern_arg(a), pattern_arg(b)};
auto widening_add(A &&a, B &&b) noexcept -> Intrin<Call::widening_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto widening_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::widening_sub, pattern_arg(a), pattern_arg(b)};
auto widening_sub(A &&a, B &&b) noexcept -> Intrin<Call::widening_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto widening_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::widening_mul, pattern_arg(a), pattern_arg(b)};
auto widening_mul(A &&a, B &&b) noexcept -> Intrin<Call::widening_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto saturating_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::saturating_add, pattern_arg(a), pattern_arg(b)};
auto saturating_add(A &&a, B &&b) noexcept -> Intrin<Call::saturating_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::saturating_sub, pattern_arg(a), pattern_arg(b)};
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<Call::saturating_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A>
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<decltype(pattern_arg(a))> {
Intrin<decltype(pattern_arg(a))> p = {Call::saturating_cast, pattern_arg(a)};
p.optional_type_hint = t;
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<Call::saturating_cast, decltype(pattern_arg(a))> {
Intrin<Call::saturating_cast, decltype(pattern_arg(a))> p = {pattern_arg(a)};
p.optional_type_hint.type = t;
return p;
}
template<typename A, typename B>
auto halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::halving_add, pattern_arg(a), pattern_arg(b)};
auto halving_add(A &&a, B &&b) noexcept -> Intrin<Call::halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::halving_sub, pattern_arg(a), pattern_arg(b)};
auto halving_sub(A &&a, B &&b) noexcept -> Intrin<Call::halving_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::rounding_halving_add, pattern_arg(a), pattern_arg(b)};
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<Call::rounding_halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::shift_left, pattern_arg(a), pattern_arg(b)};
auto shift_left(A &&a, B &&b) noexcept -> Intrin<Call::shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::shift_right, pattern_arg(a), pattern_arg(b)};
auto shift_right(A &&a, B &&b) noexcept -> Intrin<Call::shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::rounding_shift_left, pattern_arg(a), pattern_arg(b)};
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<Call::rounding_shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::rounding_shift_right, pattern_arg(a), pattern_arg(b)};
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<Call::rounding_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B, typename C>
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
return {Call::mul_shift_right, pattern_arg(a), pattern_arg(b), pattern_arg(c)};
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<Call::mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
}
template<typename A, typename B, typename C>
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
return {Call::rounding_mul_shift_right, pattern_arg(a), pattern_arg(b), pattern_arg(c)};
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<Call::rounding_mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
}

template<typename A>
auto abs(A &&a) noexcept -> Intrin<Call::abs, decltype(pattern_arg(a))> {
return {pattern_arg(a)};
}

template<typename A, typename B>
auto absd(A &&a, B &&b) noexcept -> Intrin<Call::absd, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}

template<typename A>
auto likely(A &&a) noexcept -> Intrin<Call::likely, decltype(pattern_arg(a))> {
return {pattern_arg(a)};
}

template<typename A>
auto likely_if_innermost(A &&a) noexcept -> Intrin<Call::likely_if_innermost, decltype(pattern_arg(a))> {
return {pattern_arg(a)};
}

template<typename A>
Expand Down Expand Up @@ -2425,7 +2455,8 @@ template<typename A>
struct IsInt {
struct pattern_tag {};
A a;
int bits, lanes;
uint8_t bits;
uint16_t lanes;

constexpr static uint32_t binds = bindings<A>::mask;

Expand All @@ -2448,7 +2479,7 @@ struct IsInt {
};

template<typename A>
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0, int lanes = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
HALIDE_ALWAYS_INLINE auto is_int(A &&a, uint8_t bits = 0, uint16_t lanes = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
assert_is_lvalue_if_expr<A>();
return {pattern_arg(a), bits, lanes};
}
Expand All @@ -2470,7 +2501,8 @@ template<typename A>
struct IsUInt {
struct pattern_tag {};
A a;
int bits, lanes;
uint8_t bits;
uint16_t lanes;

constexpr static uint32_t binds = bindings<A>::mask;

Expand All @@ -2493,7 +2525,7 @@ struct IsUInt {
};

template<typename A>
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0, int lanes = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, uint8_t bits = 0, uint16_t lanes = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
assert_is_lvalue_if_expr<A>();
return {pattern_arg(a), bits, lanes};
}
Expand Down
8 changes: 4 additions & 4 deletions src/Simplify_Max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) {
rewrite(max(select(x, w, max(z, y)), z), max(select(x, w, y), z)) ||
rewrite(max(select(x, w, max(z, y)), y), max(select(x, w, z), y)) ||

rewrite(max(intrin(Call::likely, x), x), b) ||
rewrite(max(x, intrin(Call::likely, x)), a) ||
rewrite(max(intrin(Call::likely_if_innermost, x), x), b) ||
rewrite(max(x, intrin(Call::likely_if_innermost, x)), a) ||
rewrite(max(likely(x), x), b) ||
rewrite(max(x, likely(x)), a) ||
rewrite(max(likely_if_innermost(x), x), b) ||
rewrite(max(x, likely_if_innermost(x)), a) ||

(no_overflow(op->type) &&
(rewrite(max(ramp(x, y, lanes), broadcast(z, lanes)), a,
Expand Down
8 changes: 4 additions & 4 deletions src/Simplify_Min.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ Expr Simplify::visit(const Min *op, ExprInfo *info) {
rewrite(min(select(x, w, min(z, y)), y), min(select(x, w, z), y)) ||
rewrite(min(select(x, w, min(z, y)), z), min(select(x, w, y), z)) ||

rewrite(min(intrin(Call::likely, x), x), b) ||
rewrite(min(x, intrin(Call::likely, x)), a) ||
rewrite(min(intrin(Call::likely_if_innermost, x), x), b) ||
rewrite(min(x, intrin(Call::likely_if_innermost, x)), a) ||
rewrite(min(likely(x), x), b) ||
rewrite(min(x, likely(x)), a) ||
rewrite(min(likely_if_innermost(x), x), b) ||
rewrite(min(x, likely_if_innermost(x)), a) ||

(no_overflow(op->type) &&
(rewrite(min(ramp(x, y, lanes), broadcast(z, lanes)), a,
Expand Down
4 changes: 2 additions & 2 deletions src/Simplify_Not.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ Expr Simplify::visit(const Not *op, ExprInfo *info) {
}

if (rewrite(!broadcast(x, c0), broadcast(!x, c0)) ||
rewrite(!intrin(Call::likely, x), intrin(Call::likely, !x)) ||
rewrite(!intrin(Call::likely_if_innermost, x), intrin(Call::likely_if_innermost, !x)) ||
rewrite(!likely(x), likely(!x)) ||
rewrite(!likely_if_innermost(x), likely_if_innermost(!x)) ||
rewrite(!(!x && y), x || !y) ||
rewrite(!(!x || y), x && !y) ||
rewrite(!(x && !y), !x || y) ||
Expand Down
16 changes: 8 additions & 8 deletions src/Simplify_Select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ Expr Simplify::visit(const Select *op, ExprInfo *info) {

// clang-format off
if (EVAL_IN_LAMBDA
(rewrite(select(IRMatcher::intrin(Call::likely, true), x, y), x) ||
rewrite(select(IRMatcher::intrin(Call::likely, false), x, y), y) ||
rewrite(select(IRMatcher::intrin(Call::likely_if_innermost, true), x, y), x) ||
rewrite(select(IRMatcher::intrin(Call::likely_if_innermost, false), x, y), y) ||
(rewrite(select(IRMatcher::likely(true), x, y), x) ||
rewrite(select(IRMatcher::likely(false), x, y), y) ||
rewrite(select(IRMatcher::likely_if_innermost(true), x, y), x) ||
rewrite(select(IRMatcher::likely_if_innermost(false), x, y), y) ||
rewrite(select(1, x, y), x) ||
rewrite(select(0, x, y), y) ||
rewrite(select(x, y, y), y) ||
rewrite(select(x, intrin(Call::likely, y), y), false_value) ||
rewrite(select(x, y, intrin(Call::likely, y)), true_value) ||
rewrite(select(x, intrin(Call::likely_if_innermost, y), y), false_value) ||
rewrite(select(x, y, intrin(Call::likely_if_innermost, y)), true_value) ||
rewrite(select(x, likely(y), y), false_value) ||
rewrite(select(x, y, likely(y)), true_value) ||
rewrite(select(x, likely_if_innermost(y), y), false_value) ||
rewrite(select(x, y, likely_if_innermost(y)), true_value) ||
false)) {
return rewrite.result;
}
Expand Down

0 comments on commit 5f17d6f

Please sign in to comment.