From e2fd289549b797d3b495469374ee975301a95751 Mon Sep 17 00:00:00 2001
From: Virdhatchani Narayanamoorthy <138196495+VirdhatchaniKN@users.noreply.github.com>
Date: Fri, 20 Sep 2024 13:38:02 +0530
Subject: [PATCH] #12829: Cleanup softplus_bw, hardtanh_bw, prod_bw (#12864)

* #12829: Cleanup softplus_bw
* #12829: Cleanup hardtanh_bw
* #12864: Update files
* #12829: Cleanup prod_bw
* #12864: Restructure ExecuteUnaryBackwardTwoFloatWithDefault
---
 .../device/unary_backward_op.cpp              |  6 +--
 .../device/unary_backward_op.hpp              | 29 -------------
 .../eltwise/unary_backward/unary_backward.hpp | 42 +++++++------------
 3 files changed, 19 insertions(+), 58 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
index 7a4766690d3..f78c7db3610 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
@@ -54,7 +54,7 @@ std::vector<Tensor> ExecuteUnaryBackwardClamp::invoke(
 
 // Hardtanh
 // result: torch.where((input <= min) | (input >= max), 0.0, grad)
-std::vector<Tensor> _hardtanh_bw(
+std::vector<Tensor> ExecuteUnaryBackwardHardtanh::invoke(
     const Tensor& grad, const Tensor& input, float min, float max, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor grad_result = ttnn::where(
@@ -81,7 +81,7 @@ std::vector<Tensor> ExecuteUnaryBackwardThreshold::invoke(
 }
 
 // Softplus
-std::vector<Tensor> _softplus_bw(
+std::vector<Tensor> ExecuteUnaryBackwardSoftplus::invoke(
     const Tensor& grad, const Tensor& input, float beta, float threshold, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor mul_input_beta = ttnn::multiply(input, beta, std::nullopt, output_mem_config);
@@ -1466,7 +1466,7 @@ Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_
 
 // Prod
 // along a single dimension --> result: grad_data * (y / input )
-std::vector<Tensor> _prod_bw(
+std::vector<Tensor> ExecuteUnaryBackwardProd::invoke(
     const Tensor& grad, const Tensor& input, bool all_dimensions, int64_t dim, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
index 894523584f4..eb2bb4927ea 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
@@ -13,8 +13,6 @@ namespace ttnn::operations::unary_backward {
 
 enum class UnaryBackwardOpType {
-    HARDTANH_BW,
-    SOFTPLUS_BW,
     DIV_BW,
     RDIV_BW,
     MULTIGAMMALN_BW,
@@ -76,7 +74,6 @@ enum class UnaryBackwardOpType {
     DEG2RAD_BW,
     POLYGAMMA_BW,
     REPEAT_BW,
-    PROD_BW,
 };
 
 std::vector<Tensor> _acos_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
@@ -132,9 +129,6 @@ std::vector<Tensor> _floor_bw( const Tensor& grad, const Tensor& input, const st
 std::vector<Tensor> _round_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
 std::vector<Tensor> _log_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
 
-std::vector<Tensor> _softplus_bw( const Tensor& grad, const Tensor& input, float beta = 1.0, float threshold = 20.0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-std::vector<Tensor> _hardtanh_bw( const Tensor& grad, const Tensor& input, float min = -1.0, float max = 1.0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-
 std::vector<Tensor> _add_bw( const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 
 std::vector<Tensor> _eq_bw( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
@@ -148,21 +142,12 @@ std::vector<Tensor> _logiteps_bw( const Tensor& grad, const Tensor& input, float
 std::vector<Tensor> _rdiv_bw( const Tensor& grad, const Tensor& input, float scalar, string round_mode = "None", const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const tt::tt_metal::LegacyShape& shape, const std::optional<MemoryConfig>& output_mem_config);
-
-std::vector<Tensor> _prod_bw( const Tensor& grad, const Tensor& input, bool all_dimensions = true, int64_t dim = 0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 
 Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_config);
 
 // OpHandler struct template
 template <UnaryBackwardOpType OpType>
 struct OpHandler;
 
-template <>
-struct OpHandler<UnaryBackwardOpType::HARDTANH_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float min, float max, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _hardtanh_bw(grad, input, min, max, output_mem_config);
-    }
-};
-
 template <>
 struct OpHandler<UnaryBackwardOpType::POW_BW> {
     static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float exponent, const std::optional<MemoryConfig>& output_mem_config ) {
@@ -555,13 +540,6 @@ struct OpHandler {
     }
 };
 
-template <>
-struct OpHandler<UnaryBackwardOpType::SOFTPLUS_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float beta, float threshold, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _softplus_bw(grad, input, beta, threshold, output_mem_config);
-    }
-};
-
 template <>
 struct OpHandler<UnaryBackwardOpType::RDIV_BW> {
     static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float scalar, string round_mode, const std::optional<MemoryConfig>& output_mem_config ) {
@@ -576,13 +554,6 @@ struct OpHandler {
     }
 };
 
-template <>
-struct OpHandler<UnaryBackwardOpType::PROD_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, bool all_dimensions, int64_t dim, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _prod_bw(grad, input, all_dimensions, dim, output_mem_config);
-    }
-};
-
 template <>
 struct OpHandler {
     static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config ) {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
index 072b3582639..7192a80a34c 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
@@ -66,17 +66,14 @@ struct ExecuteUnaryBackwardWoFloat {
 
 };
 
-template <UnaryBackwardOpType unary_backward_op_type>
-struct ExecuteUnaryBackwardTwoFloatWithDefault {
-    static std::vector<Tensor> invoke(
-        const Tensor &grad_tensor_arg,
-        const Tensor &input_tensor_arg,
-        float parameter_a,
-        float parameter_b,
-        const std::optional<MemoryConfig> &memory_config = std::nullopt) {
-        auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config());
-        return OpHandler<unary_backward_op_type>::handle(grad_tensor_arg, input_tensor_arg, parameter_a, parameter_b, output_memory_config);
-    }
+#define DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(op_name) \
+struct ExecuteUnaryBackward##op_name { \
+    static std::vector<Tensor> invoke( \
+        const Tensor &grad_tensor_arg, \
+        const Tensor &input_tensor_arg, \
+        float parameter_a, \
+        float parameter_b, \
+        const std::optional<MemoryConfig> &memory_config = std::nullopt); \
 };
 
 template
@@ -248,17 +245,13 @@ struct ExecuteUnaryBackwardFill {
         std::optional<Tensor> input_grad = std::nullopt);
 };
 
-template <UnaryBackwardOpType unary_backward_op_type>
-struct ExecuteUnaryBackwardProdBW {
+struct ExecuteUnaryBackwardProd {
     static std::vector<Tensor> invoke(
         const Tensor &grad_tensor_arg,
         const Tensor &input_tensor_arg,
         bool all_dimensions = true,
        int64_t dim = 0,
-        const std::optional<MemoryConfig> &memory_config = std::nullopt) {
-        auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config());
-        return OpHandler<unary_backward_op_type>::handle(grad_tensor_arg, input_tensor_arg, all_dimensions, dim, output_memory_config);
-    }
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
 };
 
 struct ExecuteUnaryBackwardRecip {
@@ -305,6 +298,8 @@ struct ExecuteUnaryBackwardGelu{
 
 };
 
+DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(Softplus)
+DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(Hardtanh)
 
 } // operations::unary
 
@@ -569,14 +564,9 @@ constexpr auto clamp_bw = ttnn::register_operation<
     "ttnn::clamp_bw",
     operations::unary_backward::ExecuteUnaryBackwardClamp>();
 
-constexpr auto softplus_bw = ttnn::register_operation<
-    "ttnn::softplus_bw",
-    operations::unary_backward::ExecuteUnaryBackwardTwoFloatWithDefault<
-        operations::unary_backward::UnaryBackwardOpType::SOFTPLUS_BW>>();
-constexpr auto hardtanh_bw = ttnn::register_operation<
-    "ttnn::hardtanh_bw",
-    operations::unary_backward::ExecuteUnaryBackwardTwoFloatWithDefault<
-        operations::unary_backward::UnaryBackwardOpType::HARDTANH_BW>>();
+// Tensor + Float(Default) + Float(Default)
+constexpr auto hardtanh_bw = ttnn::register_operation<"ttnn::hardtanh_bw", operations::unary_backward::ExecuteUnaryBackwardHardtanh>();
+constexpr auto softplus_bw = ttnn::register_operation<"ttnn::softplus_bw", operations::unary_backward::ExecuteUnaryBackwardSoftplus>();
 
 constexpr auto rdiv_bw = ttnn::register_operation<
     "ttnn::rdiv_bw",
@@ -612,7 +602,7 @@ constexpr auto silu_bw = ttnn::register_operation<
 
 constexpr auto prod_bw = ttnn::register_operation<
     "ttnn::prod_bw",
-    operations::unary_backward::ExecuteUnaryBackwardProdBW<operations::unary_backward::UnaryBackwardOpType::PROD_BW>>();
+    operations::unary_backward::ExecuteUnaryBackwardProd>();
 
 constexpr auto relu_bw = ttnn::register_operation<
     "ttnn::relu_bw",
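
Note on the restructuring in unary_backward.hpp: the OpType-templated
ExecuteUnaryBackwardTwoFloatWithDefault is replaced by a macro that stamps out one
concrete struct per op and moves the invoke body out of the header into
unary_backward_op.cpp. A minimal, self-contained sketch of what the macro expands to;
the stub Tensor and MemoryConfig types below stand in for the real ttnn headers and are
not part of this patch:

    #include <optional>
    #include <vector>

    struct Tensor {};        // stand-in for ttnn's Tensor
    struct MemoryConfig {};  // stand-in for ttnn's MemoryConfig

    #define DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(op_name) \
    struct ExecuteUnaryBackward##op_name { \
        static std::vector<Tensor> invoke( \
            const Tensor &grad_tensor_arg, \
            const Tensor &input_tensor_arg, \
            float parameter_a, \
            float parameter_b, \
            const std::optional<MemoryConfig> &memory_config = std::nullopt); \
    };

    // Token pasting yields one declaration-only struct per op, e.g.:
    //
    //   struct ExecuteUnaryBackwardSoftplus {
    //       static std::vector<Tensor> invoke(
    //           const Tensor &grad_tensor_arg,
    //           const Tensor &input_tensor_arg,
    //           float parameter_a,
    //           float parameter_b,
    //           const std::optional<MemoryConfig> &memory_config = std::nullopt);
    //   };
    DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(Softplus)
    DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(Hardtanh)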
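C++ call sites are unchanged apart from the backing struct names. A hedged usage sketch:
grad_tensor and input_tensor are assumed to be valid ttnn device tensors set up elsewhere,
and the registered-op wrapper is assumed to forward to invoke so that invoke's default
arguments still apply; the float values quoted below are the defaults from the deleted
_softplus_bw and _hardtanh_bw declarations:

    // softplus_bw: beta and threshold passed explicitly (old defaults: 1.0 / 20.0).
    auto softplus_grads = ttnn::softplus_bw(grad_tensor, input_tensor, 1.0f, 20.0f);
    // hardtanh_bw: min and max passed explicitly (old defaults: -1.0 / 1.0).
    auto hardtanh_grads = ttnn::hardtanh_bw(grad_tensor, input_tensor, -1.0f, 1.0f);
    // prod_bw keeps all_dimensions = true, dim = 0 as defaults on invoke itself.
    auto prod_grads = ttnn::prod_bw(grad_tensor, input_tensor);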