diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
index 7c54278ac02..7192a80a34c 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
@@ -66,22 +66,14 @@ struct ExecuteUnaryBackwardWoFloat {
 };
 
-struct ExecuteUnaryBackwardSoftplus {
-    static std::vector<Tensor> invoke(
-        const Tensor &grad_tensor_arg,
-        const Tensor &input_tensor_arg,
-        float parameter_a,
-        float parameter_b,
-        const std::optional<MemoryConfig> &memory_config = std::nullopt);
-};
-
-struct ExecuteUnaryBackwardHardtanh {
-    static std::vector<Tensor> invoke(
-        const Tensor &grad_tensor_arg,
-        const Tensor &input_tensor_arg,
-        float parameter_a,
-        float parameter_b,
-        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+#define DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(op_name) \
+struct ExecuteUnaryBackward##op_name { \
+    static std::vector<Tensor> invoke( \
+        const Tensor &grad_tensor_arg, \
+        const Tensor &input_tensor_arg, \
+        float parameter_a, \
+        float parameter_b, \
+        const std::optional<MemoryConfig> &memory_config = std::nullopt); \
 };
 
 
 template <typename T>
@@ -306,6 +298,8 @@ struct ExecuteUnaryBackwardGelu{
 };
 
+DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(Softplus)
+DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(Hardtanh)
 
 } // operations::unary
 
 
@@ -570,13 +564,9 @@ constexpr auto clamp_bw = ttnn::register_operation<
     "ttnn::clamp_bw",
     operations::unary_backward::ExecuteUnaryBackwardClamp>();
 
-constexpr auto softplus_bw = ttnn::register_operation<
-    "ttnn::softplus_bw",
-    operations::unary_backward::ExecuteUnaryBackwardSoftplus>();
-
-constexpr auto hardtanh_bw = ttnn::register_operation<
-    "ttnn::hardtanh_bw",
-    operations::unary_backward::ExecuteUnaryBackwardHardtanh>();
+// Tensor + Float(Default) + Float(Default)
+constexpr auto hardtanh_bw = ttnn::register_operation<"ttnn::hardtanh_bw", operations::unary_backward::ExecuteUnaryBackwardHardtanh>();
+constexpr auto softplus_bw = ttnn::register_operation<"ttnn::softplus_bw", operations::unary_backward::ExecuteUnaryBackwardSoftplus>();
 
 constexpr auto rdiv_bw = ttnn::register_operation<
     "ttnn::rdiv_bw",