From ffa404fb48b9445cc3490d343f76442c01aef46d Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Sat, 20 Jan 2024 01:30:43 -0500 Subject: [PATCH] [Metal] Dispatch numerically stable tanh for metal (#16438) Prior to this PR, `tanh(x)` returns `NaN` on metal when `x > 45.0`. Metal's built-in tanh is implemented as `(t - 1.0) / (t + 1.0)`, where `t = exp(2.0 * x)`. Hence for large `x`, `t` becomes `inf`, so `tanh(x)` evaluates `inf / inf`, which is `NaN`. A numerically stable `tanh` is already implemented for `llvm`; this PR lifts it to `src/target/intrin_rule.cc` and applies the same rule for metal and webgpu as well. --- src/target/intrin_rule.cc | 18 ++++++++++++++++++ src/target/intrin_rule.h | 3 +++ src/target/llvm/intrin_rule_llvm.cc | 25 ++++++------------------- src/target/source/intrin_rule_metal.cc | 2 +- src/target/source/intrin_rule_webgpu.cc | 2 +- 5 files changed, 29 insertions(+), 21 deletions(-) diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 398e24d2510e..d9fc73cb566b 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -134,6 +134,24 @@ PrimExpr DispatchFastErf(const PrimExpr& e) { return res; } +PrimExpr DispatchNumericalStableTanh(const PrimExpr& e) { + using tir::make_const; + using tir::make_zero; + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr one = make_const(x.dtype(), 1); + PrimExpr two = make_const(x.dtype(), 2); + PrimExpr neg_two = make_const(x.dtype(), -2); + + PrimExpr exp_neg2x = exp(neg_two * x); + PrimExpr exp_pos2x = exp(two * x); + + PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); + PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); + return tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); +} + } // namespace intrin namespace legalize { diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index b7f5881b3a90..2695c43173a0 100644 --- a/src/target/intrin_rule.h +++ 
b/src/target/intrin_rule.h @@ -80,6 +80,9 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) { // Dispatch ERF to fast erf when it is not available. PrimExpr DispatchFastErf(const PrimExpr& e); +// Dispatch numerically stable tanh such that tanh(large_num) does not result in NaN +PrimExpr DispatchNumericalStableTanh(const PrimExpr& e); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 9ef494fd2a0b..2730c0a34d63 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -28,6 +28,8 @@ #include #include +#include "../intrin_rule.h" + namespace tvm { namespace codegen { namespace llvm { @@ -99,6 +101,10 @@ TVM_REGISTER_OP("tir.cos").set_attr( TVM_REGISTER_OP("tir.sin").set_attr( "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); + +TVM_REGISTER_OP("tir.tanh") + .set_attr("llvm.FLowerIntrinsic", + ::tvm::codegen::intrin::DispatchNumericalStableTanh); } // namespace intrin namespace legalize { @@ -116,25 +122,6 @@ TVM_REGISTER_OP("tir.exp10") return ret; }); -TVM_REGISTER_OP("tir.tanh") - .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; - using tir::make_zero; - const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr one = make_const(x.dtype(), 1); - PrimExpr two = make_const(x.dtype(), 2); - PrimExpr neg_two = make_const(x.dtype(), -2); - - PrimExpr exp_neg2x = exp(neg_two * x); - PrimExpr exp_pos2x = exp(two * x); - - PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); - PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); - return tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); - }); - TVM_REGISTER_OP("tir.tan").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); diff --git 
a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index cc83eb1462c6..50685f6ef269 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -89,7 +89,7 @@ TVM_REGISTER_OP("tir.log10") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.tanh") - .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); + .set_attr("metal.FLowerIntrinsic", DispatchNumericalStableTanh); TVM_REGISTER_OP("tir.sqrt") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); diff --git a/src/target/source/intrin_rule_webgpu.cc b/src/target/source/intrin_rule_webgpu.cc index 81803059fc49..f3e561f71477 100644 --- a/src/target/source/intrin_rule_webgpu.cc +++ b/src/target/source/intrin_rule_webgpu.cc @@ -105,7 +105,7 @@ TVM_REGISTER_OP("tir.tan").set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.tanh") - .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + .set_attr("webgpu.FLowerIntrinsic", DispatchNumericalStableTanh); TVM_REGISTER_OP("tir.trunc") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern);