Skip to content

Commit

Permalink
[Metal] Dispatch numerically stable tanh for metal (#16438)
Browse files Browse the repository at this point in the history
Prior to this PR, `tanh(x)` returns `NaN` on metal when `x > 45.0`.

Metal's built-in tanh is implemented as `(t - 1.0) / (t + 1.0)`, where `t = exp(2.0 * x)`. Hence for large `x`, `t` becomes `inf`, causing `tanh(x)` to be `NaN`.

A numerically stable `tanh` is already implemented for `llvm`; this PR lifts it to `src/target/intrin_rule.cc` and applies the same rule for metal as well.
  • Loading branch information
CharlieFRuan authored Jan 20, 2024
1 parent 28c68e8 commit ffa404f
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 21 deletions.
18 changes: 18 additions & 0 deletions src/target/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,24 @@ PrimExpr DispatchFastErf(const PrimExpr& e) {
return res;
}

// Lowers a tanh call into an explicit expression built from exp().
//
// Two algebraically equivalent formulations are constructed:
//   x >= 0 : (1 - exp(-2x)) / (1 + exp(-2x))
//   x <  0 : (exp(2x) - 1) / (exp(2x) + 1)
// and a tir::Select picks between them on the sign of x, so the formulation
// that is selected only ever exponentiates a non-positive value and therefore
// cannot evaluate to inf / inf for large |x|.
//
// \param e The expression to lower. It must be a tir::CallNode (enforced with
//          ICHECK); its first argument is used as x.
// \return The tir::Select expression described above, in the dtype of x.
PrimExpr DispatchNumericalStableTanh(const PrimExpr& e) {
  const auto* tanh_call = e.as<tir::CallNode>();
  ICHECK(tanh_call != nullptr);

  const PrimExpr& arg = tanh_call->args[0];
  const auto dtype = arg.dtype();

  // Constants are materialised in the operand's dtype so no implicit casts
  // are introduced into the lowered expression.
  const PrimExpr unit = tir::make_const(dtype, 1);
  const PrimExpr plus_two = tir::make_const(dtype, 2);
  const PrimExpr minus_two = tir::make_const(dtype, -2);

  const PrimExpr decaying = exp(minus_two * arg);  // exp(-2x)
  const PrimExpr growing = exp(plus_two * arg);    // exp(2x)

  const PrimExpr when_non_negative = (unit - decaying) / (unit + decaying);
  const PrimExpr when_negative = (growing - unit) / (growing + unit);

  const PrimExpr is_non_negative = arg >= tir::make_zero(dtype);
  return tir::Select(is_non_negative, when_non_negative, when_negative);
}

} // namespace intrin

namespace legalize {
Expand Down
3 changes: 3 additions & 0 deletions src/target/intrin_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) {
// Dispatch ERF to fast erf when it is not available.
PrimExpr DispatchFastErf(const PrimExpr& e);

// Dispatch numerically stable tanh such that tanh(large_num) does not result in NaN
PrimExpr DispatchNumericalStableTanh(const PrimExpr& e);

} // namespace intrin
} // namespace codegen
} // namespace tvm
Expand Down
25 changes: 6 additions & 19 deletions src/target/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../intrin_rule.h"

namespace tvm {
namespace codegen {
namespace llvm {
Expand Down Expand Up @@ -99,6 +101,10 @@ TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(

TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
"llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);

// Lower tir.tanh for the llvm target through the shared numerically stable
// tanh rule declared in ../intrin_rule.h.
TVM_REGISTER_OP("tir.tanh")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
::tvm::codegen::intrin::DispatchNumericalStableTanh);
} // namespace intrin

namespace legalize {
Expand All @@ -116,25 +122,6 @@ TVM_REGISTER_OP("tir.exp10")
return ret;
});

// llvm legalization rule for tir.tanh: rewrites tanh(x) into an expression
// built from exp(), selecting between two equivalent formulations on the sign
// of x.
TVM_REGISTER_OP("tir.tanh")
.set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
using tir::make_zero;
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr);
const PrimExpr& x = call->args[0];
// Constants are created in the dtype of x.
PrimExpr one = make_const(x.dtype(), 1);
PrimExpr two = make_const(x.dtype(), 2);
PrimExpr neg_two = make_const(x.dtype(), -2);

PrimExpr exp_neg2x = exp(neg_two * x);
PrimExpr exp_pos2x = exp(two * x);

// tanh_pos is used when x >= 0 and tanh_neg otherwise, so the selected
// formulation only exponentiates a non-positive value.
PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
return tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
});

TVM_REGISTER_OP("tir.tan").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr);
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/intrin_rule_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ TVM_REGISTER_OP("tir.log10")
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.tanh")
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchNumericalStableTanh);

TVM_REGISTER_OP("tir.sqrt")
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/intrin_rule_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.tanh")
.set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchPureExtern<Direct>);
.set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchNumericalStableTanh);

TVM_REGISTER_OP("tir.trunc")
.set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchPureExtern<Direct>);
Expand Down

0 comments on commit ffa404f

Please sign in to comment.